diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index da24d7bc69..637d5e442a 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -67,6 +67,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Install libolm + run: sudo apt-get install libolm-dev libolm3 - name: Install Go uses: actions/setup-go@v3 with: @@ -101,6 +103,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v3 + - name: Install libolm + run: sudo apt-get install libolm-dev libolm3 - name: Setup go uses: actions/setup-go@v3 with: @@ -232,6 +236,8 @@ jobs: --health-retries 5 steps: - uses: actions/checkout@v3 + - name: Install libolm + run: sudo apt-get install libolm-dev libolm3 - name: Setup go uses: actions/setup-go@v3 with: diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml index a9c1718a04..bf62a1c199 100644 --- a/.github/workflows/helm.yml +++ b/.github/workflows/helm.yml @@ -38,3 +38,4 @@ jobs: with: config: helm/cr.yaml charts_dir: helm/ + mark_as_latest: false diff --git a/.gitignore b/.gitignore index 4fc0f935a2..515c09db87 100644 --- a/.gitignore +++ b/.gitignore @@ -77,7 +77,10 @@ docs/_site media_store/ -__debug_bin +# golang workspaces +go.work* + +__debug_bin* cmd/dendrite-monolith-server/dendrite-monolith-server build diff --git a/CHANGES.md b/CHANGES.md index 8052efd8a1..c99ed2255b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,59 @@ # Changelog +## Dendrite 0.13.1 (2023-07-06) + +This releases fixes a long-standing "off-by-one" error which could result in state resets. Upgrading to this version is **highly** recommended. + +When deduplicating state events, we were checking if the event in question was already in a state snapshot. If it was in a previous state snapshot, we would +then remove it from the list of events to store. If this happened, we were, unfortunately, skipping the next event to check. This resulted in +events getting stored in state snapshots where they may not be needed. When we now compared two of those state snapshots, one of them +contained the skipped event, while the other didn't. This difference possibly shouldn't exist, resulting in unexpected state resets and explains +reports of missing state events as well. + +Rooms where a state reset occurred earlier should, hopefully, reconcile over time. + +### Fixes: + +- A long-standing "off-by-one" error has been fixed, which could result in state resets +- Roomserver Prometheus Metrics are available again + +### Features + +- Updated dependencies + - Internal NATS Server has been updated from v2.9.15 to v2.9.19 + +## Dendrite 0.13.0 (2023-06-30) + +### Features + +- Results in responses to `/search` now highlight words more accurately and not only the search terms as before +- Support for connecting to appservices listening on unix sockets has been added (contributed by [cyberb](https://github.com/cyberb)) +- Admin APIs for token authenticated registration have been added (contributed by [santhoshivan23](https://github.com/santhoshivan23)) +- Initial support for [MSC4014: Pseudonymous Identities](https://github.com/matrix-org/matrix-spec-proposals/blob/kegan/pseudo-ids/proposals/4014-pseudonymous-identities.md) + - This is **highly experimental**, things like changing usernames/avatars, inviting users, upgrading rooms isn't working + +### Fixes + +- `m.upload.size` is now optional, finally allowing uploads with unlimited file size +- A bug while resolving server names has been fixed (contributed by [anton-molyboha](https://github.com/anton-molyboha)) +- Application services should only receive one invitation instead of 2 (or worse), which could result in state resets previously +- Several admin endpoints are now using `POST` instead of `GET` +- `/delete_devices` now uses user-interactive authentication +- Several "membership" (e.g `/kick`, `/ban`) endpoints are using less heavy database queries to check if the user is allowed to perform this action +- `/3pid` endpoints are now available on `/v3` instead of the `/unstable` prefix +- Upgrading rooms ignores state events of other users, which could result in failed upgrades before +- Uploading key backups with a wrong version now returns `M_WRONG_ROOM_KEYS_VERSION` +- A potential state reset when joining the same room multiple times in short sequence has been fixed +- A bug where we returned the full event as `redacted_because` in redaction events has been fixed +- The `displayname` and `avatar_url` can now be set to empty strings +- Unsafe hotserving of files has been fixed (contributed by [joshqou](https://github.com/joshqou)) +- Joining new rooms would potentially return "redacted" events, due to history visibility not being set correctly, this could result in events being rejected +- Backfilling resulting in `unsuported room version ''` should now be solved + +### Other + +- Huge refactoring of Dendrite and gomatrixserverlib + ## Dendrite 0.12.0 (2023-03-13) ### Features diff --git a/README.md b/README.md index 295203eb48..34604eff93 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ It intends to provide an **efficient**, **reliable** and **scalable** alternativ Dendrite is **beta** software, which means: -- Dendrite is ready for early adopters. We recommend running in Monolith mode with a PostgreSQL database. +- Dendrite is ready for early adopters. We recommend running Dendrite with a PostgreSQL database. - Dendrite has periodic releases. We intend to release new versions as we fix bugs and land significant features. - Dendrite supports database schema upgrades between releases. This means you should never lose your messages when upgrading Dendrite. @@ -21,7 +21,7 @@ This does not mean: - Dendrite is bug-free. It has not yet been battle-tested in the real world and so will be error prone initially. - Dendrite is feature-complete. There may be client or federation APIs that are not implemented. -- Dendrite is ready for massive homeserver deployments. There is no sharding of microservices (although it is possible to run them on separate machines) and there is no high-availability/clustering support. +- Dendrite is ready for massive homeserver deployments. There is no high-availability/clustering support. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. @@ -47,7 +47,7 @@ For a usable federating Dendrite deployment, you will also need: Also recommended are: - A PostgreSQL database engine, which will perform better than SQLite with many users and/or larger rooms -- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) +- A reverse proxy server, such as nginx, configured [like this sample](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/dendrite-sample.conf) The [Federation Tester](https://federationtester.matrix.org) can be used to verify your deployment. @@ -60,7 +60,7 @@ The following instructions are enough to get Dendrite started as a non-federatin ```bash $ git clone https://github.com/matrix-org/dendrite $ cd dendrite -$ ./build.sh +$ go build -o bin/ ./cmd/... # Generate a Matrix signing key for federation (required) $ ./bin/generate-keys --private-key matrix_key.pem @@ -85,7 +85,7 @@ Then point your favourite Matrix client at `http://localhost:8008` or `https://l ## Progress -We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver +We use a script called "Are We Synapse Yet" which checks Sytest compliance rates. Sytest is a black-box homeserver test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it updates with CI. As of January 2023, we have 100% server-server parity with Synapse, and the client-server parity is at 93% , though check CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse diff --git a/appservice/appservice.go b/appservice/appservice.go index 1f6037ee2e..d94a483e0c 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -20,10 +20,9 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" - "github.com/matrix-org/gomatrixserverlib" - appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/query" @@ -86,7 +85,7 @@ func NewInternalAPI( func generateAppServiceAccount( userAPI userapi.AppserviceUserAPI, as config.ApplicationService, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 282c631285..878ca5666e 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/test/testrig" ) @@ -326,7 +326,7 @@ func TestRoomserverConsumerOneInvite(t *testing.T) { room := test.NewRoom(t, alice) // Invite Bob - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) @@ -345,7 +345,7 @@ func TestRoomserverConsumerOneInvite(t *testing.T) { t.Fatal(err) } for _, ev := range txn.Events { - if ev.Type != gomatrixserverlib.MRoomMember { + if ev.Type != spec.MRoomMember { continue } // Usually we would check the event content for the membership, but since diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 586ca33a87..1877de37af 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -26,9 +26,11 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -103,7 +105,7 @@ func (s *OutputRoomEventConsumer) onMessage( ctx context.Context, state *appserviceState, msgs []*nats.Msg, ) bool { log.WithField("appservice", state.ID).Tracef("Appservice worker received %d message(s) from roomserver", len(msgs)) - events := make([]*gomatrixserverlib.HeaderedEvent, 0, len(msgs)) + events := make([]*types.HeaderedEvent, 0, len(msgs)) for _, msg := range msgs { // Only handle events we care about receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) @@ -173,13 +175,15 @@ func (s *OutputRoomEventConsumer) onMessage( // endpoint. It will block for the backoff period if necessary. func (s *OutputRoomEventConsumer) sendEvents( ctx context.Context, state *appserviceState, - events []*gomatrixserverlib.HeaderedEvent, + events []*types.HeaderedEvent, txnID string, ) error { // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.HeaderedToClientEvents(events, synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), }, ) if err != nil { @@ -188,7 +192,7 @@ func (s *OutputRoomEventConsumer) sendEvents( // If txnID is not defined, generate one from the events. if txnID == "" { - txnID = fmt.Sprintf("%d_%d", events[0].Event.OriginServerTS(), len(transaction)) + txnID = fmt.Sprintf("%d_%d", events[0].PDU.OriginServerTS(), len(transaction)) } // Send the transaction to the appservice. @@ -230,17 +234,27 @@ func (s *appserviceState) backoffAndPause(err error) error { // event falls within one of a given application service's namespaces. // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 -func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool { +func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { + user := "" + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err == nil { + user = userID.String() + } + switch { case appservice.URL == "": return false - case appservice.IsInterestedInUserID(event.Sender()): + case appservice.IsInterestedInUserID(user): return true case appservice.IsInterestedInRoomID(event.RoomID()): return true } - if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil { + if event.Type() == spec.MRoomMember && event.StateKey() != nil { if appservice.IsInterestedInUserID(*event.StateKey()) { return true } @@ -268,7 +282,7 @@ func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Cont // appserviceJoinedAtEvent returns a boolean depending on whether a given // appservice has membership at the time a given event was created. -func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice *config.ApplicationService) bool { +func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { // TODO: This is only checking the current room state, not the state at // the event in question. Pretty sure this is what Synapse does too, but // until we have a lighter way of checking the state before the event that @@ -286,7 +300,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e switch { case ev.StateKey == nil: continue - case ev.Type != gomatrixserverlib.MRoomMember: + case ev.Type != spec.MRoomMember: continue } var membership gomatrixserverlib.MemberContent @@ -294,7 +308,7 @@ func (s *OutputRoomEventConsumer) appserviceJoinedAtEvent(ctx context.Context, e switch { case err != nil: continue - case membership.Membership == gomatrixserverlib.Join: + case membership.Membership == spec.Join: if appservice.IsInterestedInUserID(*ev.StateKey) { return true } diff --git a/build.cmd b/build.cmd deleted file mode 100644 index 9e90622c8b..0000000000 --- a/build.cmd +++ /dev/null @@ -1,51 +0,0 @@ -@echo off - -:ENTRY_POINT - setlocal EnableDelayedExpansion - - REM script base dir - set SCRIPTDIR=%~dp0 - set PROJDIR=%SCRIPTDIR:~0,-1% - - REM Put installed packages into ./bin - set GOBIN=%PROJDIR%\bin - - set FLAGS= - - REM Check if sources are under Git control - if not exist ".git" goto :CHECK_BIN - - REM set BUILD=`git rev-parse --short HEAD \\ ""` - FOR /F "tokens=*" %%X IN ('git rev-parse --short HEAD') DO ( - set BUILD=%%X - ) - - REM set BRANCH=`(git symbolic-ref --short HEAD \ tr -d \/ ) \\ ""` - FOR /F "tokens=*" %%X IN ('git symbolic-ref --short HEAD') DO ( - set BRANCHRAW=%%X - set BRANCH=!BRANCHRAW:/=! - ) - if "%BRANCH%" == "main" set BRANCH= - - set FLAGS=-X github.com/matrix-org/dendrite/internal.branch=%BRANCH% -X github.com/matrix-org/dendrite/internal.build=%BUILD% - -:CHECK_BIN - if exist "bin" goto :ALL_SET - mkdir "bin" - -:ALL_SET - set CGO_ENABLED=1 - for /D %%P in (cmd\*) do ( - go build -trimpath -ldflags "%FLAGS%" -v -o ".\bin" ".\%%P" - ) - - set CGO_ENABLED=0 - set GOOS=js - set GOARCH=wasm - go build -trimpath -ldflags "%FLAGS%" -o bin\main.wasm .\cmd\dendritejs-pinecone - - goto :DONE - -:DONE - echo Done - endlocal \ No newline at end of file diff --git a/build.sh b/build.sh deleted file mode 100755 index f8b5001bff..0000000000 --- a/build.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/sh -eu - -# Put installed packages into ./bin -export GOBIN=$PWD/`dirname $0`/bin - -if [ -d ".git" ] -then - export BUILD=`git rev-parse --short HEAD || ""` - export BRANCH=`(git symbolic-ref --short HEAD | tr -d \/ ) || ""` - if [ "$BRANCH" = main ] - then - export BRANCH="" - fi - - export FLAGS="-X github.com/matrix-org/dendrite/internal.branch=$BRANCH -X github.com/matrix-org/dendrite/internal.build=$BUILD" -else - export FLAGS="" -fi - -mkdir -p bin - -CGO_ENABLED=1 go build -trimpath -ldflags "$FLAGS" -v -o "bin/" ./cmd/... - -# CGO_ENABLED=0 GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o bin/main.wasm ./cmd/dendritejs-pinecone diff --git a/build/docker/README.md b/build/docker/README.md index b66cb864b1..8d69b9af16 100644 --- a/build/docker/README.md +++ b/build/docker/README.md @@ -6,23 +6,20 @@ They can be found on Docker Hub: - [matrixdotorg/dendrite-monolith](https://hub.docker.com/r/matrixdotorg/dendrite-monolith) for monolith deployments -## Dockerfiles +## Dockerfile -The `Dockerfile` is a multistage file which can build all four Dendrite -images depending on the supplied `--target`. From the root of the Dendrite +The `Dockerfile` is a multistage file which can build Dendrite. From the root of the Dendrite repository, run: ``` -docker build . --target monolith -t matrixdotorg/dendrite-monolith -docker build . --target demo-pinecone -t matrixdotorg/dendrite-demo-pinecone -docker build . --target demo-yggdrasil -t matrixdotorg/dendrite-demo-yggdrasil +docker build . -t matrixdotorg/dendrite-monolith ``` -## Compose files +## Compose file -There are two sample `docker-compose` files: +There is one sample `docker-compose` files: -- `docker-compose.monolith.yml` which runs a monolith Dendrite deployment +- `docker-compose.yml` which runs a Dendrite deployment with Postgres ## Configuration @@ -55,7 +52,7 @@ Create your config based on the [`dendrite-sample.yaml`](https://github.com/matr Then start the deployment: ``` -docker-compose -f docker-compose.monolith.yml up +docker-compose -f docker-compose.yml up ``` ## Building the images diff --git a/build/docker/docker-compose.monolith.yml b/build/docker/docker-compose.monolith.yml deleted file mode 100644 index 1a8fe4eee4..0000000000 --- a/build/docker/docker-compose.monolith.yml +++ /dev/null @@ -1,44 +0,0 @@ -version: "3.4" -services: - postgres: - hostname: postgres - image: postgres:14 - restart: always - volumes: - - ./postgres/create_db.sh:/docker-entrypoint-initdb.d/20-create_db.sh - # To persist your PostgreSQL databases outside of the Docker image, - # to prevent data loss, modify the following ./path_to path: - - ./path_to/postgresql:/var/lib/postgresql/data - environment: - POSTGRES_PASSWORD: itsasecret - POSTGRES_USER: dendrite - healthcheck: - test: ["CMD-SHELL", "pg_isready -U dendrite"] - interval: 5s - timeout: 5s - retries: 5 - networks: - - internal - - monolith: - hostname: monolith - image: matrixdotorg/dendrite-monolith:latest - command: [ - "--tls-cert=server.crt", - "--tls-key=server.key" - ] - ports: - - 8008:8008 - - 8448:8448 - volumes: - - ./config:/etc/dendrite - - ./media:/var/dendrite/media - depends_on: - - postgres - networks: - - internal - restart: unless-stopped - -networks: - internal: - attachable: true diff --git a/build/docker/docker-compose.yml b/build/docker/docker-compose.yml new file mode 100644 index 0000000000..9397673f85 --- /dev/null +++ b/build/docker/docker-compose.yml @@ -0,0 +1,52 @@ +version: "3.4" + +services: + postgres: + hostname: postgres + image: postgres:15-alpine + restart: always + volumes: + # This will create a docker volume to persist the database files in. + # If you prefer those files to be outside of docker, you'll need to change this. + - dendrite_postgres_data:/var/lib/postgresql/data + environment: + POSTGRES_PASSWORD: itsasecret + POSTGRES_USER: dendrite + POSTGRES_DATABASE: dendrite + healthcheck: + test: ["CMD-SHELL", "pg_isready -U dendrite"] + interval: 5s + timeout: 5s + retries: 5 + networks: + - internal + + monolith: + hostname: monolith + image: matrixdotorg/dendrite-monolith:latest + ports: + - 8008:8008 + - 8448:8448 + volumes: + - ./config:/etc/dendrite + # The following volumes use docker volumes, change this + # if you prefer to have those files outside of docker. + - dendrite_media:/var/dendrite/media + - dendrite_jetstream:/var/dendrite/jetstream + - dendrite_search_index:/var/dendrite/searchindex + depends_on: + postgres: + condition: service_healthy + networks: + - internal + restart: unless-stopped + +networks: + internal: + attachable: true + +volumes: + dendrite_postgres_data: + dendrite_media: + dendrite_jetstream: + dendrite_search_index: \ No newline at end of file diff --git a/build/docker/postgres/create_db.sh b/build/docker/postgres/create_db.sh deleted file mode 100755 index 27d2a4df43..0000000000 --- a/build/docker/postgres/create_db.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/sh - -for db in userapi_accounts mediaapi syncapi roomserver keyserver federationapi appservice mscs; do - createdb -U dendrite -O dendrite dendrite_$db -done diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 7ce1892c91..720ce37eb0 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" _ "golang.org/x/mobile/bind" @@ -134,7 +135,7 @@ func (m *DendriteMonolith) Start() { Generate: true, SingleDatabase: true, }) - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 41dbb189e6..9d2acd68ed 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -2,13 +2,13 @@ package clientapi import ( "context" + "fmt" "net/http" "net/http/httptest" "reflect" "testing" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" @@ -19,17 +19,654 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" + capi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" uapi "github.com/matrix-org/dendrite/userapi/api" ) +func TestAdminCreateToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token1", + }, + ), + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token2", + }, + ), + }, + { + name: "Alice can create a token without specifyiing any information", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{}), + }, + { + name: "Alice can to create a token specifying a name", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token3", + }, + ), + }, + { + name: "Alice cannot to create a token that already exists", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token3", + }, + ), + }, + { + name: "Alice can create a token specifying valid params", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token4", + "uses_allowed": 5, + "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid name", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token@", + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid uses_allowed", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token5", + "uses_allowed": -1, + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid expiry_time", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token6", + "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice cannot to create a token specifying invalid length", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "length": 80, + }, + ), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new") + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new", tc.requestOpt) + } + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminListRegistrationTokens(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("valid"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("invalid"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + valid string + isValidSpecified bool + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + isValidSpecified: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + isValidSpecified: false, + }, + { + name: "Alice can list all tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can list all valid tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + valid: "true", + isValidSpecified: true, + }, + { + name: "Alice can list all invalid tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + valid: "false", + isValidSpecified: true, + }, + { + name: "No response when valid has a bad value", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + valid: "trueee", + isValidSpecified: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + var path string + if tc.isValidSpecified { + path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid) + } else { + path = "/_dendrite/admin/registrationTokens" + } + req := test.NewRequest(t, http.MethodGet, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminGetRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + token string + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + }, + { + name: "Alice can GET alice_token1", + token: "alice_token1", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can GET alice_token2", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + }, + { + name: "Alice cannot GET a token that does not exists", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token3", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodGet, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminDeleteRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + token string + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + }, + { + name: "Alice can DELETE alice_token1", + token: "alice_token1", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can DELETE alice_token2", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodDelete, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminUpdateRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + testCases := []struct { + name string + requestingUser *test.User + method string + token string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }, + ), + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }, + ), + }, + { + name: "Alice can UPDATE a token's uses_allowed property", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }), + }, + { + name: "Alice can UPDATE a token's expiry_time property", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice can UPDATE a token's uses_allowed and expiry_time property", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 20, + "expiry_time": time.Now().Add(10*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice CANNOT update a token with invalid properties", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token2", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": -5, + "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice CANNOT UPDATE a token that does not exist", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token9", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 100, + }, + ), + }, + { + name: "Alice can UPDATE token specifying uses_allowed as null - Valid for infinite uses", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": nil, + }, + ), + }, + { + name: "Alice can UPDATE token specifying expiry_time AS null - Valid for infinite time", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "expiry_time": nil, + }, + ), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodPut, path) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPut, path, tc.requestOpt) + } + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func getPointer[T any](s T) *T { + return &s +} + func TestAdminResetPassword(t *testing.T) { aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) @@ -55,10 +692,10 @@ func TestAdminResetPassword(t *testing.T) { AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", - bob: "", - vhUser: "", + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + vhUser: {}, } createAccessTokens(t, accessTokens, userAPI, ctx, routers) @@ -104,7 +741,7 @@ func TestAdminResetPassword(t *testing.T) { } if tc.withHeader { - req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser]) + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) } rec := httptest.NewRecorder() @@ -124,7 +761,7 @@ func TestPurgeRoom(t *testing.T) { room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite Bob - room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, aliceAdmin, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) @@ -134,7 +771,11 @@ func TestPurgeRoom(t *testing.T) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - defer close() + defer func() { + // give components the time to process purge requests + time.Sleep(time.Millisecond * 50) + close() + }() routers := httputil.NewRouters() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) @@ -143,8 +784,8 @@ func TestPurgeRoom(t *testing.T) { // this starts the JetStream consumers syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) - federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) - rsAPI.SetFederationAPI(nil, nil) + fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) + rsAPI.SetFederationAPI(fsAPI, nil) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { @@ -155,8 +796,8 @@ func TestPurgeRoom(t *testing.T) { AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, } createAccessTokens(t, accessTokens, userAPI, ctx, routers) @@ -175,7 +816,7 @@ func TestPurgeRoom(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) - req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) rec := httptest.NewRecorder() routers.DendriteAdmin.ServeHTTP(rec, req) @@ -195,7 +836,7 @@ func TestAdminEvacuateRoom(t *testing.T) { room := test.NewRoom(t, aliceAdmin) // Join Bob - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -225,8 +866,8 @@ func TestAdminEvacuateRoom(t *testing.T) { AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, } createAccessTokens(t, accessTokens, userAPI, ctx, routers) @@ -244,7 +885,7 @@ func TestAdminEvacuateRoom(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/evacuateRoom/"+tc.roomID) - req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) rec := httptest.NewRecorder() routers.DendriteAdmin.ServeHTTP(rec, req) @@ -292,10 +933,10 @@ func TestAdminEvacuateUser(t *testing.T) { room2 := test.NewRoom(t, aliceAdmin) // Join Bob - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room2.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room2.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -328,8 +969,8 @@ func TestAdminEvacuateUser(t *testing.T) { AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, } createAccessTokens(t, accessTokens, userAPI, ctx, routers) @@ -349,7 +990,7 @@ func TestAdminEvacuateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/evacuateUser/"+tc.userID) - req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) rec := httptest.NewRecorder() routers.DendriteAdmin.ServeHTTP(rec, req) @@ -410,8 +1051,8 @@ func TestAdminMarkAsStale(t *testing.T) { AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) // Create the users in the userapi and login - accessTokens := map[*test.User]string{ - aliceAdmin: "", + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, } createAccessTokens(t, accessTokens, userAPI, ctx, routers) @@ -429,7 +1070,7 @@ func TestAdminMarkAsStale(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/refreshDevices/"+tc.userID) - req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin].accessToken) rec := httptest.NewRecorder() routers.DendriteAdmin.ServeHTTP(rec, req) @@ -441,35 +1082,3 @@ func TestAdminMarkAsStale(t *testing.T) { } }) } - -func createAccessTokens(t *testing.T, accessTokens map[*test.User]string, userAPI uapi.UserInternalAPI, ctx context.Context, routers httputil.Routers) { - t.Helper() - for u := range accessTokens { - localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) - userRes := &uapi.PerformAccountCreationResponse{} - password := util.RandomString(8) - if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ - AccountType: u.AccountType, - Localpart: localpart, - ServerName: serverName, - Password: password, - }, userRes); err != nil { - t.Errorf("failed to create account: %s", err) - } - - req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ - "type": authtypes.LoginTypePassword, - "identifier": map[string]interface{}{ - "type": "m.id.user", - "user": u.ID, - }, - "password": password, - })) - rec := httptest.NewRecorder() - routers.Client.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("failed to login: %s", rec.Body.String()) - } - accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() - } -} diff --git a/clientapi/api/api.go b/clientapi/api/api.go index 23974c8658..28ff593fcc 100644 --- a/clientapi/api/api.go +++ b/clientapi/api/api.go @@ -21,3 +21,11 @@ type ExtraPublicRoomsProvider interface { // Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately. Rooms() []fclient.PublicRoom } + +type RegistrationToken struct { + Token *string `json:"token"` + UsesAllowed *int32 `json:"uses_allowed"` + Pending *int32 `json:"pending"` + Completed *int32 `json:"completed"` + ExpiryTime *int64 `json:"expiry_time"` +} diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index 93345f4b9d..8fae45b8d6 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -23,8 +23,8 @@ import ( "net/http" "strings" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -58,7 +58,7 @@ func VerifyUserFromRequest( if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingToken(err.Error()), + JSON: spec.MissingToken(err.Error()), } } var res api.QueryAccessTokenResponse @@ -68,21 +68,23 @@ func VerifyUserFromRequest( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccessToken failed") - jsonErr := jsonerror.InternalServerError() - return nil, &jsonErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.Err != "" { if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(res.Err), + JSON: spec.Forbidden(res.Err), } } } if res.Device == nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Unknown token"), + JSON: spec.UnknownToken("Unknown token"), } } return res.Device, nil diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 1cfe381569..0a4a12ae64 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -21,10 +21,10 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -38,7 +38,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.C if err != nil { err := &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + JSON: spec.BadJSON("Reading request body failed: " + err.Error()), } return nil, nil, err } @@ -50,7 +50,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.C if err := json.Unmarshal(reqBytes, &header); err != nil { err := &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + JSON: spec.BadJSON("Reading request body failed: " + err.Error()), } return nil, nil, err } @@ -77,7 +77,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.C default: err := util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type), + JSON: spec.InvalidParam("unhandled login type: " + header.Type), } return nil, nil, &err } diff --git a/clientapi/auth/login_jwt.go b/clientapi/auth/login_jwt.go index 35c7d19486..3ed1657941 100644 --- a/clientapi/auth/login_jwt.go +++ b/clientapi/auth/login_jwt.go @@ -8,8 +8,8 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -41,7 +41,7 @@ func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) if r.Token == "" { return nil, nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Token field for JWT is missing"), + JSON: spec.Forbidden("Token field for JWT is missing"), } } c := &Claims{} @@ -56,14 +56,14 @@ func (t *LoginTypeTokenJwt) LoginFromJSON(ctx context.Context, reqBytes []byte) util.GetLogger(ctx).WithError(err).Error("jwt.ParseWithClaims failed") return nil, nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Couldn't parse JWT"), + JSON: spec.Forbidden("Couldn't parse JWT"), } } if !token.Valid { return nil, nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Invalid JWT"), + JSON: spec.Forbidden("Invalid JWT"), } } diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index acd3fb605b..0fef99731f 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -21,12 +21,12 @@ import ( "strings" "testing" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -111,13 +111,13 @@ func TestBadLoginFromJSONReader(t *testing.T) { Name string Body string - WantErrCode string + WantErrCode spec.MatrixErrorCode }{ - {Name: "empty", WantErrCode: "M_BAD_JSON"}, + {Name: "empty", WantErrCode: spec.ErrorBadJSON}, { Name: "badUnmarshal", Body: `badsyntaxJSON`, - WantErrCode: "M_BAD_JSON", + WantErrCode: spec.ErrorBadJSON, }, { Name: "badPassword", @@ -127,7 +127,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "password": "invalidpassword", "device_id": "adevice" }`, - WantErrCode: "M_FORBIDDEN", + WantErrCode: spec.ErrorForbidden, }, { Name: "badToken", @@ -136,7 +136,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "token": "invalidtoken", "device_id": "adevice" }`, - WantErrCode: "M_FORBIDDEN", + WantErrCode: spec.ErrorForbidden, }, { Name: "badType", @@ -144,7 +144,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { "type": "m.login.invalid", "device_id": "adevice" }`, - WantErrCode: "M_INVALID_ARGUMENT_VALUE", + WantErrCode: spec.ErrorInvalidParam, }, } for _, tst := range tsts { @@ -161,7 +161,7 @@ func TestBadLoginFromJSONReader(t *testing.T) { if errRes == nil { cleanup(ctx, nil) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) - } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + } else if merr, ok := errRes.JSON.(spec.MatrixError); ok && merr.ErrCode != tst.WantErrCode { t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) } }) diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go index 293b9a4600..dd3ae8351d 100644 --- a/clientapi/auth/login_token.go +++ b/clientapi/auth/login_token.go @@ -20,9 +20,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -48,13 +48,15 @@ func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*L var res uapi.QueryLoginTokenResponse if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") - jsonErr := jsonerror.InternalServerError() - return nil, nil, &jsonErr + return nil, nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.Data == nil { return nil, nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("invalid login token"), + JSON: spec.Forbidden("invalid login token"), } } diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 756a1b611d..fa708a5070 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -22,15 +22,14 @@ import ( "github.com/go-ldap/ldap/v3" "github.com/google/uuid" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -89,14 +88,16 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("userApi.QueryLocalpartForThreePID failed") - resp := jsonerror.InternalServerError() - return nil, &resp + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown(""), + } } username = "@" + res.Localpart + ":" + string(t.Config.Matrix.ServerName) if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.Forbidden("Invalid username or password"), + JSON: spec.Forbidden("Invalid username or password"), } } } else { @@ -105,26 +106,26 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("A username must be supplied."), + JSON: spec.BadJSON("A username must be supplied."), } } if len(r.Password) == 0 { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("A password must be supplied."), + JSON: spec.BadJSON("A password must be supplied."), } } localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } if !t.Config.Matrix.IsLocalServerName(domain) { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername("The server name is not known."), + JSON: spec.InvalidUsername("The server name is not known."), } } @@ -135,7 +136,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if !ok { return nil, &util.JSONResponse{ Code: http.StatusTooManyRequests, - JSON: jsonerror.LimitExceeded("Too Many Requests", retryIn.Milliseconds()), + JSON: spec.LimitExceeded("Too Many Requests", retryIn.Milliseconds()), } } } @@ -147,7 +148,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to fetch account by password."), + JSON: spec.Unknown("Unable to fetch account by password."), } } @@ -177,7 +178,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, return &r.Login, nil } -func (t *LoginTypePassword) authenticateDb(ctx context.Context, localpart string, domain gomatrixserverlib.ServerName, password string) (*api.Account, *util.JSONResponse) { +func (t *LoginTypePassword) authenticateDb(ctx context.Context, localpart string, domain spec.ServerName, password string) (*api.Account, *util.JSONResponse) { res := &api.QueryAccountByPasswordResponse{} err := t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: strings.ToLower(localpart), @@ -187,7 +188,7 @@ func (t *LoginTypePassword) authenticateDb(ctx context.Context, localpart string if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to fetch account by password."), + JSON: spec.Unknown("Unable to fetch account by password."), } } @@ -202,7 +203,7 @@ func (t *LoginTypePassword) authenticateDb(ctx context.Context, localpart string if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to fetch account by password."), + JSON: spec.Unknown("Unable to fetch account by password."), } } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows @@ -213,7 +214,7 @@ func (t *LoginTypePassword) authenticateDb(ctx context.Context, localpart string } return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + JSON: spec.Forbidden("The username or password was incorrect or the account does not exist."), } } } @@ -226,9 +227,10 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * if err != nil { return false, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to connect to ldap: " + err.Error()), + JSON: spec.Unknown("unable to connect to ldap: " + err.Error()), } } + // nolint: errcheck defer conn.Close() if t.Config.Ldap.AdminBindEnabled { @@ -236,7 +238,7 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * if err != nil { return false, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to bind to ldap: " + err.Error()), + JSON: spec.Unknown("unable to bind to ldap: " + err.Error()), } } filter := strings.ReplaceAll(t.Config.Ldap.SearchFilter, "{username}", username) @@ -249,19 +251,19 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * if err != nil { return false, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to bind to search ldap: " + err.Error()), + JSON: spec.Unknown("unable to bind to search ldap: " + err.Error()), } } if len(result.Entries) > 1 { return false, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("'user' must be duplicated."), + JSON: spec.BadJSON("'user' must be duplicated."), } } if len(result.Entries) < 1 { return false, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("'user' not found."), + JSON: spec.BadJSON("'user' not found."), } } @@ -273,7 +275,7 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * if err != nil { return false, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } if t.Rt != nil { @@ -281,7 +283,7 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * } return false, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + JSON: spec.Forbidden("The username or password was incorrect or the account does not exist."), } } } else { @@ -293,7 +295,7 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * if err != nil { return false, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } if t.Rt != nil { @@ -301,7 +303,7 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * } return false, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + JSON: spec.Forbidden("The username or password was incorrect or the account does not exist."), } } } @@ -310,7 +312,7 @@ func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, * if err != nil { return false, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } return isAdmin, nil @@ -335,7 +337,7 @@ func (t *LoginTypePassword) isLdapAdmin(conn *ldap.Conn, username string) (bool, return true, nil } -func (t *LoginTypePassword) getOrCreateAccount(ctx context.Context, localpart string, domain gomatrixserverlib.ServerName, admin bool) (*api.Account, *util.JSONResponse) { +func (t *LoginTypePassword) getOrCreateAccount(ctx context.Context, localpart string, domain spec.ServerName, admin bool) (*api.Account, *util.JSONResponse) { var existing api.QueryAccountByLocalpartResponse err := t.UserLoginAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ Localpart: localpart, @@ -348,7 +350,7 @@ func (t *LoginTypePassword) getOrCreateAccount(ctx context.Context, localpart st if err != sql.ErrNoRows { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } @@ -369,12 +371,12 @@ func (t *LoginTypePassword) getOrCreateAccount(ctx context.Context, localpart st if _, ok := err.(*api.ErrorConflict); ok { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), + JSON: spec.UserInUse("Desired user ID is already taken."), } } return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + JSON: spec.Unknown("failed to create account: " + err.Error()), } } return created.Account, nil diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 850d3e586b..4da0746a51 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -20,9 +20,9 @@ import ( "net/http" "sync" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -180,8 +180,10 @@ func (u *UserInteractive) NewSession() *util.JSONResponse { sessionID, err := GenerateAccessToken() if err != nil { logrus.WithError(err).Error("failed to generate session ID") - res := jsonerror.InternalServerError() - return &res + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } u.Lock() u.Sessions[sessionID] = []string{} @@ -195,15 +197,19 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter mixedObjects := make(map[string]interface{}) b, err := json.Marshal(response) if err != nil { - ise := jsonerror.InternalServerError() - return &ise + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } _ = json.Unmarshal(b, &mixedObjects) challenge := u.challenge(sessionID) b, err = json.Marshal(challenge.JSON) if err != nil { - ise := jsonerror.InternalServerError() - return &ise + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } _ = json.Unmarshal(b, &mixedObjects) @@ -236,7 +242,7 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, _ *api.D if !ok { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Unknown auth.type: " + authType), + JSON: spec.BadJSON("Unknown auth.type: " + authType), } } @@ -252,7 +258,7 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, _ *api.D if !u.IsSingleStageFlow(authType) { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("The auth.session is missing or unknown."), + JSON: spec.Unknown("The auth.session is missing or unknown."), } } } diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 3b89bd3b39..9525ada514 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -8,14 +8,14 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) var ( ctx = context.Background() - serverName = gomatrixserverlib.ServerName("example.com") + serverName = spec.ServerName("example.com") // space separated localpart+password -> account lookup = make(map[string]*api.Account) device = &api.Device{ diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index dacd6718e9..fe597ffebf 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -37,7 +37,7 @@ func AddPublicRoutes( routers httputil.Routers, cfg *config.Dendrite, natsInstance *jetstream.NATSInstance, - federation *fclient.FederationClient, + federation fclient.FederationClient, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, transactionsCache *transactions.Cache, diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go new file mode 100644 index 0000000000..3a4ae4ff9e --- /dev/null +++ b/clientapi/clientapi_test.go @@ -0,0 +1,2133 @@ +package clientapi + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/pushrules" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type userDevice struct { + accessToken string + deviceID string + password string +} + +func TestGetPutDevices(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testCases := []struct { + name string + requestUser *test.User + deviceUser *test.User + request *http.Request + wantStatusCode int + validateFunc func(t *testing.T, device userDevice, routers httputil.Routers) + }{ + { + name: "can get all devices", + requestUser: alice, + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + { + name: "can get specific own device", + requestUser: alice, + deviceUser: alice, + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + { + name: "can not get device for different user", + requestUser: alice, + deviceUser: bob, + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can update own device", + requestUser: alice, + deviceUser: alice, + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/devices/", strings.NewReader(`{"display_name":"my new displayname"}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, device userDevice, routers httputil.Routers) { + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/"+device.deviceID, strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+device.accessToken) + rec := httptest.NewRecorder() + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "display_name").Str + if gotDisplayName != "my new displayname" { + t.Fatalf("expected displayname '%s', got '%s'", "my new displayname", gotDisplayName) + } + }, + }, + { + // this should return "device does not exist" + name: "can not update device for different user", + requestUser: alice, + deviceUser: bob, + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/devices/", strings.NewReader(`{"display_name":"my new displayname"}`)), + wantStatusCode: http.StatusNotFound, + }, + } + + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dev := accessTokens[tc.requestUser] + if tc.deviceUser != nil { + tc.request = httptest.NewRequest(tc.request.Method, tc.request.RequestURI+accessTokens[tc.deviceUser].deviceID, tc.request.Body) + } + tc.request.Header.Set("Authorization", "Bearer "+dev.accessToken) + rec := httptest.NewRecorder() + routers.Client.ServeHTTP(rec, tc.request) + if rec.Code != tc.wantStatusCode { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + if tc.wantStatusCode != http.StatusOK && rec.Code != http.StatusOK { + return + } + if tc.validateFunc != nil { + tc.validateFunc(t, dev, routers) + } + }) + } + }) +} + +// Deleting devices requires the UIA dance, so do this in a different test +func TestDeleteDevice(t *testing.T) { + alice := test.NewUser(t) + localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + + // create the account and an initial device + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + // create some more devices + accessToken := util.RandomString(8) + devRes := &uapi.PerformDeviceCreationResponse{} + if err := userAPI.PerformDeviceCreation(processCtx.Context(), &uapi.PerformDeviceCreationRequest{ + Localpart: localpart, + ServerName: serverName, + AccessToken: accessToken, + NoDeviceListUpdate: true, + }, devRes); err != nil { + t.Fatal(err) + } + if !devRes.DeviceCreated { + t.Fatalf("failed to create device") + } + secondDeviceID := devRes.Device.ID + + // initiate UIA for the second device + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+secondDeviceID, strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected HTTP 401, got %d: %s", rec.Code, rec.Body.String()) + } + // get the session ID + sessionID := gjson.GetBytes(rec.Body.Bytes(), "session").Str + + // prepare UIA request body + reqBody := bytes.Buffer{} + if err := json.NewEncoder(&reqBody).Encode(map[string]interface{}{ + "auth": map[string]string{ + "session": sessionID, + "type": authtypes.LoginTypePassword, + "user": alice.ID, + "password": accessTokens[alice].password, + }, + }); err != nil { + t.Fatal(err) + } + + // copy the request body, so we can use it again for the successful delete + reqBody2 := reqBody + + // do the same request again, this time with our UIA, but for a different device ID, this should fail + rec = httptest.NewRecorder() + + req = httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+accessTokens[alice].deviceID, &reqBody) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusForbidden { + t.Fatalf("expected HTTP 403, got %d: %s", rec.Code, rec.Body.String()) + } + + // do the same request again, this time with our UIA, but for the correct device ID, this should be fine + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+secondDeviceID, &reqBody2) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // verify devices are deleted + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + for _, device := range gjson.GetBytes(rec.Body.Bytes(), "devices.#.device_id").Array() { + if device.Str == secondDeviceID { + t.Fatalf("expected device %s to be deleted, but wasn't", secondDeviceID) + } + } + }) +} + +// Deleting devices requires the UIA dance, so do this in a different test +func TestDeleteDevices(t *testing.T) { + alice := test.NewUser(t) + localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + + // create the account and an initial device + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + // create some more devices + var devices []string + for i := 0; i < 10; i++ { + accessToken := util.RandomString(8) + devRes := &uapi.PerformDeviceCreationResponse{} + if err := userAPI.PerformDeviceCreation(processCtx.Context(), &uapi.PerformDeviceCreationRequest{ + Localpart: localpart, + ServerName: serverName, + AccessToken: accessToken, + NoDeviceListUpdate: true, + }, devRes); err != nil { + t.Fatal(err) + } + if !devRes.DeviceCreated { + t.Fatalf("failed to create device") + } + devices = append(devices, devRes.Device.ID) + } + + // initiate UIA + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/delete_devices", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected HTTP 401, got %d: %s", rec.Code, rec.Body.String()) + } + // get the session ID + sessionID := gjson.GetBytes(rec.Body.Bytes(), "session").Str + + // prepare UIA request body + reqBody := bytes.Buffer{} + if err := json.NewEncoder(&reqBody).Encode(map[string]interface{}{ + "auth": map[string]string{ + "session": sessionID, + "type": authtypes.LoginTypePassword, + "user": alice.ID, + "password": accessTokens[alice].password, + }, + "devices": devices[5:], + }); err != nil { + t.Fatal(err) + } + + // do the same request again, this time with our UIA, + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/delete_devices", &reqBody) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // verify devices are deleted + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + for _, device := range gjson.GetBytes(rec.Body.Bytes(), "devices.#.device_id").Array() { + for _, deletedDevice := range devices[5:] { + if device.Str == deletedDevice { + t.Fatalf("expected device %s to be deleted, but wasn't", deletedDevice) + } + } + } + }) +} + +func createAccessTokens(t *testing.T, accessTokens map[*test.User]userDevice, userAPI uapi.UserInternalAPI, ctx context.Context, routers httputil.Routers) { + t.Helper() + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + routers.Client.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = userDevice{ + accessToken: gjson.GetBytes(rec.Body.Bytes(), "access_token").String(), + deviceID: gjson.GetBytes(rec.Body.Bytes(), "device_id").String(), + password: password, + } + } +} + +func TestSetDisplayname(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + notLocalUser := &test.User{ID: "@charlie:localhost", Localpart: "charlie"} + changeDisplayName := "my new display name" + + testCases := []struct { + name string + user *test.User + wantOK bool + changeReq io.Reader + wantDisplayName string + }{ + { + name: "invalid user", + user: &test.User{ID: "!notauser"}, + }, + { + name: "non-existent user", + user: &test.User{ID: "@doesnotexist:test"}, + }, + { + name: "non-local user is not allowed", + user: notLocalUser, + }, + { + name: "existing user is allowed to change own name", + user: alice, + wantOK: true, + wantDisplayName: changeDisplayName, + }, + { + name: "existing user is not allowed to change own name if name is empty", + user: bob, + wantOK: false, + wantDisplayName: "", + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wantDisplayName := tc.user.Localpart + if tc.changeReq == nil { + tc.changeReq = strings.NewReader("") + } + + // check profile after initial account creation + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID, strings.NewReader("")) + t.Logf("%s", req.URL.String()) + routers.Client.ServeHTTP(rec, req) + + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d", rec.Code) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "displayname").Str; tc.wantOK && gotDisplayName != wantDisplayName { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantDisplayName, gotDisplayName) + } + + // now set the new display name + wantDisplayName = tc.wantDisplayName + tc.changeReq = strings.NewReader(fmt.Sprintf(`{"displayname":"%s"}`, tc.wantDisplayName)) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/profile/"+tc.user.ID+"/displayname", tc.changeReq) + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.user].accessToken) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // now only get the display name + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID+"/displayname", strings.NewReader("")) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "displayname").Str; tc.wantOK && gotDisplayName != wantDisplayName { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantDisplayName, gotDisplayName) + } + }) + } + }) +} + +func TestSetAvatarURL(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + notLocalUser := &test.User{ID: "@charlie:localhost", Localpart: "charlie"} + changeDisplayName := "mxc://newMXID" + + testCases := []struct { + name string + user *test.User + wantOK bool + changeReq io.Reader + avatar_url string + }{ + { + name: "invalid user", + user: &test.User{ID: "!notauser"}, + }, + { + name: "non-existent user", + user: &test.User{ID: "@doesnotexist:test"}, + }, + { + name: "non-local user is not allowed", + user: notLocalUser, + }, + { + name: "existing user is allowed to change own avatar", + user: alice, + wantOK: true, + avatar_url: changeDisplayName, + }, + { + name: "existing user is not allowed to change own avatar if avatar is empty", + user: bob, + wantOK: false, + avatar_url: "", + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) + + AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wantAvatarURL := "" + if tc.changeReq == nil { + tc.changeReq = strings.NewReader("") + } + + // check profile after initial account creation + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID, strings.NewReader("")) + t.Logf("%s", req.URL.String()) + routers.Client.ServeHTTP(rec, req) + + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d", rec.Code) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "avatar_url").Str; tc.wantOK && gotDisplayName != wantAvatarURL { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantAvatarURL, gotDisplayName) + } + + // now set the new display name + wantAvatarURL = tc.avatar_url + tc.changeReq = strings.NewReader(fmt.Sprintf(`{"avatar_url":"%s"}`, tc.avatar_url)) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/profile/"+tc.user.ID+"/avatar_url", tc.changeReq) + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.user].accessToken) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + // now only get the display name + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID+"/avatar_url", strings.NewReader("")) + + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "avatar_url").Str; tc.wantOK && gotDisplayName != wantAvatarURL { + t.Fatalf("expected displayname to be '%s', but got '%s'", wantAvatarURL, gotDisplayName) + } + }) + } + }) +} + +func TestTyping(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed to create accounts + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + typingForUser string + roomID string + requestBody io.Reader + wantOK bool + }{ + { + name: "can not set typing for different user", + typingForUser: "@notourself:test", + roomID: room.ID, + requestBody: strings.NewReader(""), + }, + { + name: "invalid request body", + typingForUser: alice.ID, + roomID: room.ID, + requestBody: strings.NewReader(""), + }, + { + name: "non-existent room", + typingForUser: alice.ID, + roomID: "!doesnotexist:test", + }, + { + name: "invalid room ID", + typingForUser: alice.ID, + roomID: "@notaroomid:test", + }, + { + name: "allowed to set own typing status", + typingForUser: alice.ID, + roomID: room.ID, + requestBody: strings.NewReader(`{"typing":true}`), + wantOK: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/rooms/"+tc.roomID+"/typing/"+tc.typingForUser, tc.requestBody) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestMembership(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + // Needed to create accounts + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + rsAPI.SetUserAPI(userAPI) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatal(err) + } + + invalidBodyRequest := func(roomID, membershipType string) *http.Request { + return httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", roomID, membershipType), strings.NewReader("")) + } + + missingUserIDRequest := func(roomID, membershipType string) *http.Request { + return httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", roomID, membershipType), strings.NewReader("{}")) + } + + testCases := []struct { + name string + roomID string + request *http.Request + wantOK bool + asUser *test.User + }{ + { + name: "ban - invalid request body", + request: invalidBodyRequest(room.ID, "ban"), + }, + { + name: "kick - invalid request body", + request: invalidBodyRequest(room.ID, "kick"), + }, + { + name: "unban - invalid request body", + request: invalidBodyRequest(room.ID, "unban"), + }, + { + name: "invite - invalid request body", + request: invalidBodyRequest(room.ID, "invite"), + }, + { + name: "ban - missing user_id body", + request: missingUserIDRequest(room.ID, "ban"), + }, + { + name: "kick - missing user_id body", + request: missingUserIDRequest(room.ID, "kick"), + }, + { + name: "unban - missing user_id body", + request: missingUserIDRequest(room.ID, "unban"), + }, + { + name: "invite - missing user_id body", + request: missingUserIDRequest(room.ID, "invite"), + }, + { + name: "Bob forgets invalid room", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist", "forget"), strings.NewReader("")), + asUser: bob, + }, + { + name: "Alice can not ban Bob in non-existent room", // fails because "not joined" + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist:test", "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + }, + { + name: "Alice can not kick Bob in non-existent room", // fails because "not joined" + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist:test", "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + }, + // the following must run in sequence, as they build up on each other + { + name: "Alice invites Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Bob accepts invite", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "join"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + { + name: "Alice verifies that Bob is joined", // returns an error if no membership event can be found + request: httptest.NewRequest(http.MethodGet, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s/m.room.member/%s", room.ID, "state", bob.ID), strings.NewReader("")), + wantOK: true, + }, + { + name: "Bob forgets the room but is still a member", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")), + wantOK: false, // user is still in the room + asUser: bob, + }, + { + name: "Bob can not kick Alice", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, alice.ID))), + wantOK: false, // powerlevel too low + asUser: bob, + }, + { + name: "Bob can not ban Alice", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, alice.ID))), + wantOK: false, // powerlevel too low + asUser: bob, + }, + { + name: "Alice can kick Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Alice can ban Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Alice can not kick Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: false, // can not kick banned/left user + }, + { + name: "Bob can not unban himself", // mostly because of not being a member of the room + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + asUser: bob, + }, + { + name: "Alice can not invite Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: false, // user still banned + }, + { + name: "Alice can unban Bob", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Alice can not unban Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: false, + }, + { + name: "Alice can invite Bob again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))), + wantOK: true, + }, + { + name: "Bob can reject the invite by leaving", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "leave"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + { + name: "Bob can forget the room", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + { + name: "Bob can forget the room again", + request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")), + wantOK: true, + asUser: bob, + }, + // END must run in sequence + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.asUser == nil { + tc.asUser = alice + } + rec := httptest.NewRecorder() + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[tc.asUser].accessToken) + routers.Client.ServeHTTP(rec, tc.request) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) + } + if !tc.wantOK && rec.Code == http.StatusOK { + t.Fatalf("expected request to fail, but didn't: %s", rec.Body.String()) + } + t.Logf("%s", rec.Body.String()) + }) + } + }) +} + +func TestCapabilities(t *testing.T) { + alice := test.NewUser(t) + ctx := context.Background() + + // construct the expected result + versionsMap := map[gomatrixserverlib.RoomVersion]string{} + for v, desc := range version.SupportedRoomVersions() { + if desc.Stable() { + versionsMap[v] = "stable" + } else { + versionsMap[v] = "unstable" + } + } + + expectedMap := map[string]interface{}{ + "capabilities": map[string]interface{}{ + "m.change_password": map[string]bool{ + "enabled": true, + }, + "m.room_versions": map[string]interface{}{ + "default": version.DefaultRoomVersion(), + "available": versionsMap, + }, + }, + } + + expectedBuf := &bytes.Buffer{} + err := json.NewEncoder(expectedBuf).Encode(expectedMap) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + // Needed to create accounts + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + request *http.Request + }{ + { + name: "can get capabilities", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/capabilities", strings.NewReader("")), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, tc.request) + assert.Equal(t, http.StatusOK, rec.Code) + assert.ObjectsAreEqual(expectedBuf.Bytes(), rec.Body.Bytes()) + }) + } + }) +} + +func TestTurnserver(t *testing.T) { + alice := test.NewUser(t) + ctx := context.Background() + + cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypeSQLite) + cfg.ClientAPI.RateLimiting.Enabled = false + defer close() + natsInstance := jetstream.NATSInstance{} + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + + // Needed to create accounts + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + //rsAPI.SetUserAPI(userAPI) + // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + // Create the users in the userapi and login + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + + testCases := []struct { + name string + turnConfig config.TURN + wantEmptyResponse bool + }{ + { + name: "no turn server configured", + wantEmptyResponse: true, + }, + { + name: "servers configured but not userLifeTime", + wantEmptyResponse: true, + turnConfig: config.TURN{URIs: []string{""}}, + }, + { + name: "missing sharedSecret/username/password", + wantEmptyResponse: true, + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m"}, + }, + { + name: "with shared secret", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", SharedSecret: "iAmSecret"}, + }, + { + name: "with username/password secret", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username", Password: "iAmSecret"}, + }, + { + name: "only username set", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username"}, + wantEmptyResponse: true, + }, + { + name: "only password set", + turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username"}, + wantEmptyResponse: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/voip/turnServer", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + cfg.ClientAPI.TURN = tc.turnConfig + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + + if tc.wantEmptyResponse && rec.Body.String() != "{}" { + t.Fatalf("expected an empty response, but got %s", rec.Body.String()) + } + if !tc.wantEmptyResponse { + assert.NotEqual(t, "{}", rec.Body.String()) + + resp := gomatrix.RespTurnServer{} + err := json.NewDecoder(rec.Body).Decode(&resp) + assert.NoError(t, err) + + duration, _ := time.ParseDuration(tc.turnConfig.UserLifetime) + assert.Equal(t, tc.turnConfig.URIs, resp.URIs) + assert.Equal(t, int(duration.Seconds()), resp.TTL) + if tc.turnConfig.Username != "" && tc.turnConfig.Password != "" { + assert.Equal(t, tc.turnConfig.Username, resp.Username) + assert.Equal(t, tc.turnConfig.Password, resp.Password) + } + } + }) + } +} + +// TODO: Disable for now. Make it work soon. +// func Test3PID(t *testing.T) { +// alice := test.NewUser(t) +// ctx := context.Background() + +// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { +// cfg, processCtx, close := testrig.CreateConfig(t, dbType) +// cfg.ClientAPI.RateLimiting.Enabled = false +// cfg.FederationAPI.DisableTLSValidation = true // needed to be able to connect to our identityServer below +// defer close() +// natsInstance := jetstream.NATSInstance{} + +// routers := httputil.NewRouters() +// cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + +// // Needed to create accounts +// rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) +// userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) +// // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. +// AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + +// // Create the users in the userapi and login +// accessTokens := map[*test.User]userDevice{ +// alice: {}, +// } +// createAccessTokens(t, accessTokens, userAPI, ctx, routers) + +// identityServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// switch { +// case strings.Contains(r.URL.String(), "getValidated3pid"): +// resp := threepid.GetValidatedResponse{} +// switch r.URL.Query().Get("client_secret") { +// case "fail": +// resp.ErrCode = string(spec.ErrorSessionNotValidated) +// case "fail2": +// resp.ErrCode = "some other error" +// case "fail3": +// _, _ = w.Write([]byte("{invalidJson")) +// return +// case "success": +// resp.Medium = "email" +// case "success2": +// resp.Medium = "email" +// resp.Address = "somerandom@address.com" +// } +// _ = json.NewEncoder(w).Encode(resp) +// case strings.Contains(r.URL.String(), "requestToken"): +// resp := threepid.SID{SID: "randomSID"} +// _ = json.NewEncoder(w).Encode(resp) +// } +// })) +// defer identityServer.Close() + +// identityServerBase := strings.TrimPrefix(identityServer.URL, "https://") + +// testCases := []struct { +// name string +// request *http.Request +// wantOK bool +// setTrustedServer bool +// wantLen3PIDs int +// }{ +// { +// name: "can get associated threepid info", +// request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")), +// wantOK: true, +// }, +// { +// name: "can not set threepid info with invalid JSON", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader("")), +// }, +// { +// name: "can not set threepid info with untrusted server", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader("{}")), +// }, +// { +// name: "can check threepid info with trusted server, but unverified", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail"}}`, identityServerBase))), +// setTrustedServer: true, +// wantOK: false, +// }, +// { +// name: "can check threepid info with trusted server, but fails for some other reason", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail2"}}`, identityServerBase))), +// setTrustedServer: true, +// wantOK: false, +// }, +// { +// name: "can check threepid info with trusted server, but fails because of invalid json", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail3"}}`, identityServerBase))), +// setTrustedServer: true, +// wantOK: false, +// }, +// { +// name: "can save threepid info with trusted server", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"success"}}`, identityServerBase))), +// setTrustedServer: true, +// wantOK: true, +// }, +// { +// name: "can save threepid info with trusted server using bind=true", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"success2"},"bind":true}`, identityServerBase))), +// setTrustedServer: true, +// wantOK: true, +// }, +// { +// name: "can get associated threepid info again", +// request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")), +// wantOK: true, +// wantLen3PIDs: 2, +// }, +// { +// name: "can delete associated threepid info", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/delete", strings.NewReader(`{"medium":"email","address":"somerandom@address.com"}`)), +// wantOK: true, +// }, +// { +// name: "can get associated threepid after deleting association", +// request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")), +// wantOK: true, +// wantLen3PIDs: 1, +// }, +// { +// name: "can not request emailToken with invalid request body", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader("")), +// }, +// { +// name: "can not request emailToken for in use address", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader(fmt.Sprintf(`{"client_secret":"somesecret","email":"","send_attempt":1,"id_server":"%s"}`, identityServerBase))), +// }, +// { +// name: "can request emailToken", +// request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader(fmt.Sprintf(`{"client_secret":"somesecret","email":"somerandom@address.com","send_attempt":1,"id_server":"%s"}`, identityServerBase))), +// wantOK: true, +// }, +// } + +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { + +// if tc.setTrustedServer { +// cfg.Global.TrustedIDServers = []string{identityServerBase} +// } + +// rec := httptest.NewRecorder() +// tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + +// routers.Client.ServeHTTP(rec, tc.request) +// t.Logf("Response: %s", rec.Body.String()) +// if tc.wantOK && rec.Code != http.StatusOK { +// t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String()) +// } +// if !tc.wantOK && rec.Code == http.StatusOK { +// t.Fatalf("expected request to fail, but didn't: %s", rec.Body.String()) +// } +// if tc.wantLen3PIDs > 0 { +// var resp routing.ThreePIDsResponse +// if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { +// t.Fatal(err) +// } +// if len(resp.ThreePIDs) != tc.wantLen3PIDs { +// t.Fatalf("expected %d threepids, got %d", tc.wantLen3PIDs, len(resp.ThreePIDs)) +// } +// } +// }) +// } +// }) +// } + +func TestPushRules(t *testing.T) { + alice := test.NewUser(t) + + // create the default push rules, used when validating responses + localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) + defaultRules, err := json.Marshal(pushRuleSets) + assert.NoError(t, err) + + ruleID1 := "myrule" + ruleID2 := "myrule2" + ruleID3 := "myrule3" + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + testCases := []struct { + name string + request *http.Request + wantStatusCode int + validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough + queryAttr map[string]string + }{ + { + name: "can not get rules without trailing slash", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get default rules", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, defaultRules, respBody.Bytes()) + }, + }, + { + name: "can get rules by scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String()) + }, + }, + { + name: "can not get invalid rules by scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules for invalid scope and kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules for invalid kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get rules by scope and kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String()) + }, + }, + { + name: "can get rules by scope and content kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String()) + }, + }, + { + name: "can not get rules by scope and room kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules by scope and sender kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get rules by scope and underride kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String()) + }, + }, + { + name: "can not get rules by scope, kind and ID for invalid scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get rules by scope, kind and ID for invalid kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can get rules by scope, kind and ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + { + name: "can not get rules by scope, kind and ID for invalid ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can not get status for invalid attribute", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get status for invalid kind", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get enabled status for invalid scope", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not get enabled status for invalid rule", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can get enabled rules by scope, kind and ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled") + }, + }, + { + name: "can get actions scope, kind and ID", + request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + actions := gjson.GetBytes(respBody.Bytes(), "actions").Array() + // only a basic check + assert.Equal(t, 1, len(actions)) + }, + }, + { + name: "can not set enabled status with invalid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid attribute", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid scope", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid kind", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not set attribute for invalid rule", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can set enabled status with valid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String()) + }, + }, + { + name: "can set actions with valid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String()) + }, + }, + { + name: "can not create new push rule with invalid JSON", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not create new push rule with invalid rule content", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not create new push rule with invalid scope", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can create new push rule with valid rule content", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String()) + }, + }, + { + name: "can not create new push starting with a dot", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can create new push rule after existing", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + queryAttr: map[string]string{ + "after": ruleID1, + }, + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + rules := gjson.ParseBytes(rec.Body.Bytes()) + for i, rule := range rules.Array() { + if rule.Get("rule_id").Str == ruleID1 && i != 0 { + t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1) + } + if rule.Get("rule_id").Str == ruleID2 && i != 1 { + t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2) + } + } + }, + }, + { + name: "can create new push rule before existing", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)), + queryAttr: map[string]string{ + "before": ruleID1, + }, + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + rules := gjson.ParseBytes(rec.Body.Bytes()) + for i, rule := range rules.Array() { + if rule.Get("rule_id").Str == ruleID3 && i != 0 { + t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3) + } + if rule.Get("rule_id").Str == ruleID1 && i != 1 { + t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1) + } + if rule.Get("rule_id").Str == ruleID2 && i != 2 { + t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1) + } + } + }, + }, + { + name: "can modify existing push rule", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)), + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array() + // there should only be one action + assert.Equal(t, "dont_notify", actions[0].Str) + }, + }, + { + name: "can move existing push rule to the front", + request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)), + queryAttr: map[string]string{ + "before": ruleID3, + }, + wantStatusCode: http.StatusOK, + validateFunc: func(t *testing.T, respBody *bytes.Buffer) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + rules := gjson.ParseBytes(rec.Body.Bytes()) + for i, rule := range rules.Array() { + if rule.Get("rule_id").Str == ruleID2 && i != 0 { + t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2) + } + if rule.Get("rule_id").Str == ruleID3 && i != 1 { + t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3) + } + if rule.Get("rule_id").Str == ruleID1 && i != 2 { + t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1) + } + } + }, + }, + { + name: "can not delete push rule with invalid scope", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not delete push rule with invalid kind", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "can not delete push rule with non-existent rule", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")), + wantStatusCode: http.StatusNotFound, + }, + { + name: "can delete existing push rule", + request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")), + wantStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + + if tc.queryAttr != nil { + params := url.Values{} + for k, v := range tc.queryAttr { + params.Set(k, v) + } + + tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body) + } + + tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + + routers.Client.ServeHTTP(rec, tc.request) + assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String()) + if tc.validateFunc != nil { + tc.validateFunc(t, rec.Body) + } + t.Logf("%s", rec.Body.String()) + }) + } + }) +} + +// Tests the `/keys` endpoints. +// Note that this only tests the happy path. +func TestKeys(t *testing.T) { + alice := test.NewUser(t) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + // Start a TLSServer with our client mux + srv := httptest.NewTLSServer(routers.Client) + defer srv.Close() + + cl, err := mautrix.NewClient(srv.URL, id.UserID(alice.ID), accessTokens[alice].accessToken) + if err != nil { + t.Fatal(err) + } + // Set the client so the self-signed certificate is trusted + cl.Client = srv.Client() + cl.DeviceID = id.DeviceID(accessTokens[alice].deviceID) + + cs := crypto.NewMemoryStore(nil) + oc := crypto.NewOlmMachine(cl, nil, cs, dummyStore{}) + if err = oc.Load(); err != nil { + t.Fatal(err) + } + + // tests `/keys/upload` + if err = oc.ShareKeys(ctx, 0); err != nil { + t.Fatal(err) + } + + // tests `/keys/device_signing/upload` + _, err = oc.GenerateAndUploadCrossSigningKeys(accessTokens[alice].password, "passphrase") + if err != nil { + t.Fatal(err) + } + + // tests `/keys/query` + dev, err := oc.GetOrFetchDevice(ctx, id.UserID(alice.ID), id.DeviceID(accessTokens[alice].deviceID)) + if err != nil { + t.Fatal(err) + } + + // Validate that the keys returned from the server are what the client has stored + oi := oc.OwnIdentity() + if oi.SigningKey != dev.SigningKey { + t.Fatalf("expected signing key '%s', got '%s'", oi.SigningKey, dev.SigningKey) + } + if oi.IdentityKey != dev.IdentityKey { + t.Fatalf("expected identity '%s', got '%s'", oi.IdentityKey, dev.IdentityKey) + } + + // tests `/keys/signatures/upload` + if err = oc.SignOwnMasterKey(); err != nil { + t.Fatal(err) + } + + // tests `/keys/claim` + otks := make(map[string]map[string]string) + otks[alice.ID] = map[string]string{ + accessTokens[alice].deviceID: string(id.KeyAlgorithmSignedCurve25519), + } + + data, err := json.Marshal(claimKeysRequest{OneTimeKeys: otks}) + if err != nil { + t.Fatal(err) + } + req, err := http.NewRequest(http.MethodPost, srv.URL+"/_matrix/client/v3/keys/claim", bytes.NewBuffer(data)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + resp, err := srv.Client().Do(req) + if err != nil { + t.Fatal(err) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if !gjson.GetBytes(respBody, "one_time_keys."+alice.ID+"."+string(dev.DeviceID)).Exists() { + t.Fatalf("expected one time keys for alice, but didn't find any: %s", string(respBody)) + } + }) +} + +type claimKeysRequest struct { + // The keys to be claimed. A map from user ID, to a map from device ID to algorithm name. + OneTimeKeys map[string]map[string]string `json:"one_time_keys"` +} + +type dummyStore struct{} + +func (d dummyStore) IsEncrypted(roomID id.RoomID) bool { + return true +} + +func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent { + return &event.EncryptionEventContent{} +} + +func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID { + return []id.RoomID{} +} + +func TestKeyBackup(t *testing.T) { + alice := test.NewUser(t) + + handleResponseCode := func(t *testing.T, rec *httptest.ResponseRecorder, expectedCode int) { + t.Helper() + if rec.Code != expectedCode { + t.Fatalf("expected HTTP %d, but got %d: %s", expectedCode, rec.Code, rec.Body.String()) + } + } + + testCases := []struct { + name string + request func(t *testing.T) *http.Request + validate func(t *testing.T, rec *httptest.ResponseRecorder) + }{ + { + name: "can not create backup with invalid JSON", + request: func(t *testing.T) *http.Request { + reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) // missing closing braces + return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can not create backup with missing auth_data", // as this would result in MarshalJSON errors when querying again + request: func(t *testing.T) *http.Request { + reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`) + return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can create backup", + request: func(t *testing.T) *http.Request { + reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1","auth_data":{"data":"random"}}`) + return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + wantVersion := "1" + if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion { + t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion) + } + }, + }, + { + name: "can not query backup for invalid version", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/1337", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can not query backup for invalid version string", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/notanumber", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can query backup", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + wantVersion := "1" + if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion { + t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion) + } + }, + }, + { + name: "can query backup without returning rooms", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) > 0 { + t.Fatalf("expected no rooms in version, but got %#v", gotRooms) + } + }, + }, + { + name: "can query backup for invalid room", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotSessions := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotSessions) > 0 { + t.Fatalf("expected no sessions in version, but got %#v", gotSessions) + } + }, + }, + { + name: "can not query backup for invalid session", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test/doesnotexist", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can not update backup with missing version", + request: func(t *testing.T) *http.Request { + return test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys") + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can not update backup with invalid data", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, "") + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{ + "version": "0", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusBadRequest) + }, + }, + { + name: "can not update backup with wrong version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, map[string]interface{}{ + "rooms": map[string]interface{}{ + "!testroom:test": map[string]interface{}{ + "sessions": map[string]uapi.KeyBackupSession{}, + }, + }, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{ + "version": "5", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusForbidden) + }, + }, + { + name: "can update backup with correct version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, map[string]interface{}{ + "rooms": map[string]interface{}{ + "!testroom:test": map[string]interface{}{ + "sessions": map[string]uapi.KeyBackupSession{ + "dummySession": { + FirstMessageIndex: 1, + }, + }, + }, + }, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "can update backup with correct version for specific room", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, map[string]interface{}{ + "sessions": map[string]uapi.KeyBackupSession{ + "dummySession": { + FirstMessageIndex: 1, + IsVerified: true, + SessionData: json.RawMessage("{}"), + }, + }, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test", reqBody, test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + t.Logf("%#v", rec.Body.String()) + }, + }, + { + name: "can update backup with correct version for specific room and session", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{ + FirstMessageIndex: 1, + SessionData: json.RawMessage("{}"), + IsVerified: true, + ForwardedCount: 0, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", reqBody, test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "can update backup by version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{ + FirstMessageIndex: 1, + SessionData: json.RawMessage("{}"), + IsVerified: true, + ForwardedCount: 0, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/1", reqBody, test.WithQueryParams(map[string]string{"version": "1"})) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + t.Logf("%#v", rec.Body.String()) + }, + }, + { + name: "can not update backup by version for invalid version", + request: func(t *testing.T) *http.Request { + reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{ + FirstMessageIndex: 1, + SessionData: json.RawMessage("{}"), + IsVerified: true, + ForwardedCount: 0, + }) + req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/2", reqBody, test.WithQueryParams(map[string]string{"version": "1"})) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "can query backup sessions", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) != 1 { + t.Fatalf("expected one room in response, but got %#v", rec.Body.String()) + } + }, + }, + { + name: "can query backup sessions by room", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotRooms) != 1 { + t.Fatalf("expected one session in response, but got %#v", rec.Body.String()) + } + }, + }, + { + name: "can query backup sessions by room and sessionID", + request: func(t *testing.T) *http.Request { + req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", test.WithQueryParams(map[string]string{ + "version": "1", + })) + return req + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + if !gjson.GetBytes(rec.Body.Bytes(), "is_verified").Bool() { + t.Fatalf("expected session to be verified, but wasn't: %#v", rec.Body.String()) + } + }, + }, + { + name: "can not delete invalid version backup", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + { + name: "can delete version backup", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "deleting the same backup version twice doesn't error", + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusOK) + }, + }, + { + name: "deleting an empty version doesn't work", // make sure we can't delete an empty backup version. Handled at the router level + request: func(t *testing.T) *http.Request { + return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/", nil) + }, + validate: func(t *testing.T, rec *httptest.ResponseRecorder) { + handleResponseCode(t, rec, http.StatusNotFound) + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting.Enabled = false + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + natsInstance := jetstream.NATSInstance{} + defer close() + + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := tc.request(t) + req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken) + routers.Client.ServeHTTP(rec, req) + tc.validate(t, rec) + }) + } + }) +} diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 74f84f1e73..d9f4423231 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -20,7 +20,7 @@ import ( "net/http" "unicode/utf8" - "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -32,8 +32,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon body, err := io.ReadAll(req.Body) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed") - resp := jsonerror.InternalServerError() - return &resp + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return UnmarshalJSON(body, iface) @@ -43,7 +45,7 @@ func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("Body contains invalid UTF-8"), + JSON: spec.NotJSON("Body contains invalid UTF-8"), } } @@ -53,7 +55,7 @@ func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { // valid JSON with incorrect types for values. return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } return nil diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go deleted file mode 100644 index be7d13a96c..0000000000 --- a/clientapi/jsonerror/jsonerror.go +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jsonerror - -import ( - "context" - "fmt" - "net/http" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" -) - -// MatrixError represents the "standard error response" in Matrix. -// http://matrix.org/docs/spec/client_server/r0.2.0.html#api-standards -type MatrixError struct { - ErrCode string `json:"errcode"` - Err string `json:"error"` -} - -func (e MatrixError) Error() string { - return fmt.Sprintf("%s: %s", e.ErrCode, e.Err) -} - -// InternalServerError returns a 500 Internal Server Error in a matrix-compliant -// format. -func InternalServerError() util.JSONResponse { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: Unknown("Internal Server Error"), - } -} - -// Unknown is an unexpected error -func Unknown(msg string) *MatrixError { - return &MatrixError{"M_UNKNOWN", msg} -} - -// Forbidden is an error when the client tries to access a resource -// they are not allowed to access. -func Forbidden(msg string) *MatrixError { - return &MatrixError{"M_FORBIDDEN", msg} -} - -// BadJSON is an error when the client supplies malformed JSON. -func BadJSON(msg string) *MatrixError { - return &MatrixError{"M_BAD_JSON", msg} -} - -// BadAlias is an error when the client supplies a bad alias. -func BadAlias(msg string) *MatrixError { - return &MatrixError{"M_BAD_ALIAS", msg} -} - -// NotJSON is an error when the client supplies something that is not JSON -// to a JSON endpoint. -func NotJSON(msg string) *MatrixError { - return &MatrixError{"M_NOT_JSON", msg} -} - -// NotFound is an error when the client tries to access an unknown resource. -func NotFound(msg string) *MatrixError { - return &MatrixError{"M_NOT_FOUND", msg} -} - -// MissingArgument is an error when the client tries to access a resource -// without providing an argument that is required. -func MissingArgument(msg string) *MatrixError { - return &MatrixError{"M_MISSING_ARGUMENT", msg} -} - -// InvalidArgumentValue is an error when the client tries to provide an -// invalid value for a valid argument -func InvalidArgumentValue(msg string) *MatrixError { - return &MatrixError{"M_INVALID_ARGUMENT_VALUE", msg} -} - -// MissingToken is an error when the client tries to access a resource which -// requires authentication without supplying credentials. -func MissingToken(msg string) *MatrixError { - return &MatrixError{"M_MISSING_TOKEN", msg} -} - -// UnknownToken is an error when the client tries to access a resource which -// requires authentication and supplies an unrecognised token -func UnknownToken(msg string) *MatrixError { - return &MatrixError{"M_UNKNOWN_TOKEN", msg} -} - -// WeakPassword is an error which is returned when the client tries to register -// using a weak password. http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based -func WeakPassword(msg string) *MatrixError { - return &MatrixError{"M_WEAK_PASSWORD", msg} -} - -// InvalidUsername is an error returned when the client tries to register an -// invalid username -func InvalidUsername(msg string) *MatrixError { - return &MatrixError{"M_INVALID_USERNAME", msg} -} - -// UserInUse is an error returned when the client tries to register an -// username that already exists -func UserInUse(msg string) *MatrixError { - return &MatrixError{"M_USER_IN_USE", msg} -} - -// RoomInUse is an error returned when the client tries to make a room -// that already exists -func RoomInUse(msg string) *MatrixError { - return &MatrixError{"M_ROOM_IN_USE", msg} -} - -// ASExclusive is an error returned when an application service tries to -// register an username that is outside of its registered namespace, or if a -// user attempts to register a username or room alias within an exclusive -// namespace. -func ASExclusive(msg string) *MatrixError { - return &MatrixError{"M_EXCLUSIVE", msg} -} - -// GuestAccessForbidden is an error which is returned when the client is -// forbidden from accessing a resource as a guest. -func GuestAccessForbidden(msg string) *MatrixError { - return &MatrixError{"M_GUEST_ACCESS_FORBIDDEN", msg} -} - -// InvalidSignature is an error which is returned when the client tries -// to upload invalid signatures. -func InvalidSignature(msg string) *MatrixError { - return &MatrixError{"M_INVALID_SIGNATURE", msg} -} - -// InvalidParam is an error that is returned when a parameter was invalid, -// traditionally with cross-signing. -func InvalidParam(msg string) *MatrixError { - return &MatrixError{"M_INVALID_PARAM", msg} -} - -// MissingParam is an error that is returned when a parameter was incorrect, -// traditionally with cross-signing. -func MissingParam(msg string) *MatrixError { - return &MatrixError{"M_MISSING_PARAM", msg} -} - -// UnableToAuthoriseJoin is an error that is returned when a server can't -// determine whether to allow a restricted join or not. -func UnableToAuthoriseJoin(msg string) *MatrixError { - return &MatrixError{"M_UNABLE_TO_AUTHORISE_JOIN", msg} -} - -// LeaveServerNoticeError is an error returned when trying to reject an invite -// for a server notice room. -func LeaveServerNoticeError() *MatrixError { - return &MatrixError{ - ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM", - Err: "You cannot reject this invite", - } -} - -type IncompatibleRoomVersionError struct { - RoomVersion string `json:"room_version"` - Error string `json:"error"` - Code string `json:"errcode"` -} - -// IncompatibleRoomVersion is an error which is returned when the client -// requests a room with a version that is unsupported. -func IncompatibleRoomVersion(roomVersion gomatrixserverlib.RoomVersion) *IncompatibleRoomVersionError { - return &IncompatibleRoomVersionError{ - Code: "M_INCOMPATIBLE_ROOM_VERSION", - RoomVersion: string(roomVersion), - Error: "Your homeserver does not support the features required to join this room", - } -} - -// UnsupportedRoomVersion is an error which is returned when the client -// requests a room with a version that is unsupported. -func UnsupportedRoomVersion(msg string) *MatrixError { - return &MatrixError{"M_UNSUPPORTED_ROOM_VERSION", msg} -} - -// LimitExceededError is a rate-limiting error. -type LimitExceededError struct { - MatrixError - RetryAfterMS int64 `json:"retry_after_ms,omitempty"` -} - -// LimitExceeded is an error when the client tries to send events too quickly. -func LimitExceeded(msg string, retryAfterMS int64) *LimitExceededError { - return &LimitExceededError{ - MatrixError: MatrixError{"M_LIMIT_EXCEEDED", msg}, - RetryAfterMS: retryAfterMS, - } -} - -// NotTrusted is an error which is returned when the client asks the server to -// proxy a request (e.g. 3PID association) to a server that isn't trusted -func NotTrusted(serverName string) *MatrixError { - return &MatrixError{ - ErrCode: "M_SERVER_NOT_TRUSTED", - Err: fmt.Sprintf("Untrusted server '%s'", serverName), - } -} - -// InternalAPIError is returned when Dendrite failed to reach an internal API. -func InternalAPIError(ctx context.Context, err error) util.JSONResponse { - logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: &MatrixError{ - ErrCode: "M_INTERNAL_SERVER_ERROR", - Err: "Dendrite encountered an error reaching an internal API.", - }, - } -} diff --git a/clientapi/jsonerror/jsonerror_test.go b/clientapi/jsonerror/jsonerror_test.go deleted file mode 100644 index 9f3754cbc5..0000000000 --- a/clientapi/jsonerror/jsonerror_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jsonerror - -import ( - "encoding/json" - "testing" -) - -func TestLimitExceeded(t *testing.T) { - e := LimitExceeded("too fast", 5000) - jsonBytes, err := json.Marshal(&e) - if err != nil { - t.Fatalf("TestLimitExceeded: Failed to marshal LimitExceeded error. %s", err.Error()) - } - want := `{"errcode":"M_LIMIT_EXCEEDED","error":"too fast","retry_after_ms":5000}` - if string(jsonBytes) != want { - t.Errorf("TestLimitExceeded: want %s, got %s", want, string(jsonBytes)) - } -} - -func TestForbidden(t *testing.T) { - e := Forbidden("you shall not pass") - jsonBytes, err := json.Marshal(&e) - if err != nil { - t.Fatalf("TestForbidden: Failed to marshal Forbidden error. %s", err.Error()) - } - want := `{"errcode":"M_FORBIDDEN","error":"you shall not pass"}` - if string(jsonBytes) != want { - t.Errorf("TestForbidden: want %s, got %s", want, string(jsonBytes)) - } -} diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index d0ea084182..7399397b7c 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -38,13 +39,13 @@ type SyncAPIProducer struct { TopicPresenceEvent string TopicMultiRoomCast string JetStream nats.JetStreamContext - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName UserAPI userapi.ClientUserAPI } func (p *SyncAPIProducer) SendReceipt( ctx context.Context, - userID, roomID, eventID, receiptType string, timestamp gomatrixserverlib.Timestamp, + userID, roomID, eventID, receiptType string, timestamp spec.Timestamp, ) error { m := &nats.Msg{ Subject: p.TopicReceiptEvent, @@ -155,7 +156,7 @@ func (p *SyncAPIProducer) SendPresence( m.Header.Set("status_msg", *statusMsg) } - m.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) + m.Header.Set("last_active_ts", strconv.Itoa(int(spec.AsTimestamp(time.Now())))) _, err := p.JetStream.PublishMsg(m, nats.Context(ctx)) return err diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 4742b12409..81afc3b13f 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -21,11 +21,11 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -38,7 +38,7 @@ func GetAccountData( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -69,7 +69,7 @@ func GetAccountData( return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("data not found"), + JSON: spec.NotFound("data not found"), } } @@ -81,7 +81,7 @@ func SaveAccountData( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -90,27 +90,30 @@ func SaveAccountData( if req.Body == http.NoBody { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("Content not JSON"), + JSON: spec.NotJSON("Content not JSON"), } } if dataType == "m.fully_read" || dataType == "m.push_rules" { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)), + JSON: spec.Forbidden(fmt.Sprintf("Unable to modify %q using this API", dataType)), } } body, err := io.ReadAll(req.Body) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !json.Valid(body) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Bad JSON content"), + JSON: spec.BadJSON("Bad JSON content"), } } @@ -142,8 +145,16 @@ func SaveReadMarker( userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("userID for this device is invalid"), + } + } + // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -157,7 +168,10 @@ func SaveReadMarker( if r.FullyRead != "" { data, err := json.Marshal(fullyReadEvent{EventID: r.FullyRead}) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } dataReq := api.InputAccountDataRequest{ diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 76e18f2f8d..519666076e 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -3,112 +3,328 @@ package routing import ( "context" "encoding/json" + "errors" "fmt" "net/http" + "regexp" + "strconv" "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" + "golang.org/x/exp/constraints" - "github.com/matrix-org/dendrite/clientapi/jsonerror" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) -func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) +var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + +func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + if !cfg.RegistrationRequiresToken { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"), + } + } + request := struct { + Token string `json:"token"` + UsesAllowed *int32 `json:"uses_allowed,omitempty"` + ExpiryTime *int64 `json:"expiry_time,omitempty"` + Length int32 `json:"length"` + }{} + + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } + } + + token := request.Token + usesAllowed := request.UsesAllowed + expiryTime := request.ExpiryTime + length := request.Length + + if len(token) == 0 { + if length == 0 { + // length not provided in request. Assign default value of 16. + length = 16 + } + // token not present in request body. Hence, generate a random token. + if length <= 0 || length > 64 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("length must be greater than zero and not greater than 64"), + } + } + token = util.RandomString(int(length)) + } + + if len(token) > 64 { + //Token present in request body, but is too long. + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must not be longer than 64"), + } + } + + isTokenValid := validRegistrationTokenRegex.Match([]byte(token)) + if !isTokenValid { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"), + } + } + // At this point, we have a valid token, either through request body or through random generation. + if usesAllowed != nil && *usesAllowed < 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } + } + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } + } + pending := int32(0) + completed := int32(0) + // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) + registrationToken := &clientapi.RegistrationToken{ + Token: &token, + UsesAllowed: usesAllowed, + Pending: &pending, + Completed: &completed, + ExpiryTime: expiryTime, + } + created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) + if !created { + return util.JSONResponse{ + Code: http.StatusConflict, + JSON: map[string]string{ + "error": fmt.Sprintf("token: %s already exists", token), + }, + } + } if err != nil { - return util.ErrorResponse(err) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } } - res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} - if err := rsAPI.PerformAdminEvacuateRoom( - req.Context(), - &roomserverAPI.PerformAdminEvacuateRoomRequest{ - RoomID: vars["roomID"], + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "token": token, + "uses_allowed": getReturnValue(usesAllowed), + "pending": pending, + "completed": completed, + "expiry_time": getReturnValue(expiryTime), }, - res, - ); err != nil { - return util.ErrorResponse(err) } - if err := res.Error; err != nil { - return err.JSONResponse() +} + +func getReturnValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + queryParams := req.URL.Query() + returnAll := true + valid := true + validQuery, ok := queryParams["valid"] + if ok { + returnAll = false + validValue, err := strconv.ParseBool(validQuery[0]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid 'valid' query parameter"), + } + } + valid = validValue + } + tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.ErrorUnknown, + } } return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ - "affected": res.Affected, + "registration_tokens": tokens, }, } } -func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - userID := vars["userID"] + tokenText := vars["token"] + token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } + } + return util.JSONResponse{ + Code: 200, + JSON: token, + } +} - _, domain, err := gomatrixserverlib.SplitID('@', userID) +func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) + return util.ErrorResponse(err) + } + tokenText := vars["token"] + err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } +} + +func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) } - if !cfg.Matrix.IsLocalServerName(domain) { + tokenText := vars["token"] + request := make(map[string]*int64) + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("User ID must belong to this server."), + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), } } - res := &roomserverAPI.PerformAdminEvacuateUserResponse{} - if err := rsAPI.PerformAdminEvacuateUser( - req.Context(), - &roomserverAPI.PerformAdminEvacuateUserRequest{ - UserID: userID, - }, - res, - ); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) + newAttributes := make(map[string]interface{}) + usesAllowed, ok := request["uses_allowed"] + if ok { + // Only add usesAllowed to newAtrributes if it is present and valid + if usesAllowed != nil && *usesAllowed < 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } + } + newAttributes["usesAllowed"] = usesAllowed + } + expiryTime, ok := request["expiry_time"] + if ok { + // Only add expiryTime to newAtrributes if it is present and valid + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } + } + newAttributes["expiryTime"] = expiryTime + } + if len(newAttributes) == 0 { + // No attributes to update. Return existing token + return AdminGetRegistrationToken(req, cfg, userAPI) + } + updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } + } + return util.JSONResponse{ + Code: 200, + JSON: *updatedToken, + } +} + +func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) } - if err := res.Error; err != nil { - return err.JSONResponse() + + affected, err := rsAPI.PerformAdminEvacuateRoom(req.Context(), vars["roomID"]) + switch err.(type) { + case nil: + case eventutil.ErrRoomNoExists: + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(err.Error()), + } + default: + logrus.WithError(err).WithField("roomID", vars["roomID"]).Error("Failed to evacuate room") + return util.ErrorResponse(err) } return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ - "affected": res.Affected, + "affected": affected, }, } } -func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateUser(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - roomID := vars["roomID"] - res := &roomserverAPI.PerformAdminPurgeRoomResponse{} - if err := rsAPI.PerformAdminPurgeRoom( - context.Background(), - &roomserverAPI.PerformAdminPurgeRoomRequest{ - RoomID: roomID, + affected, err := rsAPI.PerformAdminEvacuateUser(req.Context(), vars["userID"]) + if err != nil { + logrus.WithError(err).WithField("userID", vars["userID"]).Error("Failed to evacuate user") + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "affected": affected, }, - res, - ); err != nil { + } +} + +func AdminPurgeRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { return util.ErrorResponse(err) } - if err := res.Error; err != nil { - return err.JSONResponse() + + if err = rsAPI.PerformAdminPurgeRoom(context.Background(), vars["roomID"]); err != nil { + return util.ErrorResponse(err) } + return util.JSONResponse{ Code: 200, - JSON: res, + JSON: struct{}{}, } } @@ -116,7 +332,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De if req.Body == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Missing request body"), + JSON: spec.Unknown("Missing request body"), } } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -129,7 +345,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } accAvailableResp := &api.QueryAccountAvailabilityResponse{} @@ -139,28 +355,29 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De }, accAvailableResp); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalAPIError(req.Context(), err), + JSON: spec.InternalServerError{}, } } if accAvailableResp.Available { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.Unknown("User does not exist"), + JSON: spec.Unknown("User does not exist"), } } request := struct { - Password string `json:"password"` + Password string `json:"password"` + LogoutDevices bool `json:"logout_devices"` }{} if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), + JSON: spec.Unknown("Failed to decode request body: " + err.Error()), } } if request.Password == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting non-empty password."), + JSON: spec.MissingParam("Expecting non-empty password."), } } @@ -172,13 +389,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.De Localpart: localpart, ServerName: serverName, Password: request.Password, - LogoutDevices: true, + LogoutDevices: request.LogoutDevices, } updateRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to perform password update: " + err.Error()), + JSON: spec.Unknown("Failed to perform password update: " + err.Error()), } } return util.JSONResponse{ @@ -195,7 +412,10 @@ func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -217,7 +437,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), + JSON: spec.InvalidParam("Can not mark local device list as stale"), } } @@ -228,7 +448,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown(fmt.Sprintf("Failed to mark device list as stale: %s", err)), + JSON: spec.Unknown(fmt.Sprintf("Failed to mark device list as stale: %s", err)), } } return util.JSONResponse{ @@ -237,7 +457,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien } } -func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminDownloadState(req *http.Request, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -246,33 +466,32 @@ func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.De if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting room ID."), + JSON: spec.MissingParam("Expecting room ID."), } } serverName, ok := vars["serverName"] if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting remote server name."), + JSON: spec.MissingParam("Expecting remote server name."), } } - res := &roomserverAPI.PerformAdminDownloadStateResponse{} - if err := rsAPI.PerformAdminDownloadState( - req.Context(), - &roomserverAPI.PerformAdminDownloadStateRequest{ - UserID: device.UserID, - RoomID: roomID, - ServerName: gomatrixserverlib.ServerName(serverName), - }, - res, - ); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if err := res.Error; err != nil { - return err.JSONResponse() + if err = rsAPI.PerformAdminDownloadState(req.Context(), roomID, device.UserID, spec.ServerName(serverName)); err != nil { + if errors.Is(err, eventutil.ErrRoomNoExists{}) { + return util.JSONResponse{ + Code: 200, + JSON: spec.NotFound(err.Error()), + } + } + logrus.WithError(err).WithFields(logrus.Fields{ + "userID": device.UserID, + "serverName": serverName, + "roomID": roomID, + }).Error("failed to download state") + return util.ErrorResponse(err) } return util.JSONResponse{ Code: 200, - JSON: map[string]interface{}{}, + JSON: struct{}{}, } } diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index f1cbd34678..7d7536564e 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -17,8 +17,8 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -51,7 +51,7 @@ func GetAdminWhois( if !allowed { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -61,7 +61,10 @@ func GetAdminWhois( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("GetAdminWhois failed to query user devices") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } devices := make(map[string]deviceInfo) diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go index 68d0f41959..2d6b72d3ea 100644 --- a/clientapi/routing/aliases.go +++ b/clientapi/routing/aliases.go @@ -15,14 +15,14 @@ package routing import ( + "encoding/json" "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -31,7 +31,7 @@ func GetAliases( req *http.Request, rsAPI api.ClientRoomserverAPI, device *userapi.Device, roomID string, ) util.JSONResponse { stateTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", } stateReq := &api.QueryCurrentStateRequest{ @@ -47,26 +47,37 @@ func GetAliases( visibility := gomatrixserverlib.HistoryVisibilityInvited if historyVisEvent, ok := stateRes.StateEvents[stateTuple]; ok { var err error - visibility, err = historyVisEvent.HistoryVisibility() - if err != nil { + var content gomatrixserverlib.HistoryVisibilityContent + if err = json.Unmarshal(historyVisEvent.Content(), &content); err != nil { util.GetLogger(req.Context()).WithError(err).Error("historyVisEvent.HistoryVisibility failed") return util.ErrorResponse(fmt.Errorf("historyVisEvent.HistoryVisibility: %w", err)) } + visibility = content.HistoryVisibility } - if visibility != gomatrixserverlib.WorldReadable { + if visibility != spec.WorldReadable { + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *deviceUserID, } var queryRes api.QueryMembershipForUserResponse if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !queryRes.IsInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You aren't a member of this room."), + JSON: spec.Forbidden("You aren't a member of this room."), } } } diff --git a/clientapi/routing/capabilities.go b/clientapi/routing/capabilities.go index e6c1a9b8cb..fa50fa1aa5 100644 --- a/clientapi/routing/capabilities.go +++ b/clientapi/routing/capabilities.go @@ -27,7 +27,7 @@ import ( func GetCapabilities() util.JSONResponse { versionsMap := map[gomatrixserverlib.RoomVersion]string{} for v, desc := range version.SupportedRoomVersions() { - if desc.Stable { + if desc.Stable() { versionsMap[v] = "stable" } else { versionsMap[v] = "unstable" diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index a0d80903dd..320f236cb1 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -26,10 +26,9 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -38,33 +37,19 @@ import ( // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom type createRoomRequest struct { - Invite []string `json:"invite"` - Name string `json:"name"` - Visibility string `json:"visibility"` - Topic string `json:"topic"` - Preset string `json:"preset"` - CreationContent json.RawMessage `json:"creation_content"` - InitialState []fledglingEvent `json:"initial_state"` - RoomAliasName string `json:"room_alias_name"` - GuestCanJoin bool `json:"guest_can_join"` - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"` - IsDirect bool `json:"is_direct"` + Invite []string `json:"invite"` + Name string `json:"name"` + Visibility string `json:"visibility"` + Topic string `json:"topic"` + Preset string `json:"preset"` + CreationContent json.RawMessage `json:"creation_content"` + InitialState []gomatrixserverlib.FledglingEvent `json:"initial_state"` + RoomAliasName string `json:"room_alias_name"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"` + IsDirect bool `json:"is_direct"` } -const ( - presetPrivateChat = "private_chat" - presetTrustedPrivateChat = "trusted_private_chat" - presetPublicChat = "public_chat" -) - -const ( - historyVisibilityShared = "shared" - // TODO: These should be implemented once history visibility is implemented - // historyVisibilityWorldReadable = "world_readable" - // historyVisibilityInvited = "invited" -) - func (r createRoomRequest) Validate() *util.JSONResponse { whitespace := "\t\n\x0b\x0c\r " // https://docs.python.org/2/library/string.html#string.whitespace // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/room.py#L81 @@ -72,28 +57,23 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if strings.ContainsAny(r.RoomAliasName, whitespace+":") { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("room_alias_name cannot contain whitespace or ':'"), + JSON: spec.BadJSON("room_alias_name cannot contain whitespace or ':'"), } } for _, userID := range r.Invite { - // TODO: We should put user ID parsing code into gomatrixserverlib and use that instead - // (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 ) - // It should be a struct (with pointers into a single string to avoid copying) and - // we should update all refs to use UserID types rather than strings. - // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92 - if _, _, err := gomatrixserverlib.SplitID('@', userID); err != nil { + if _, err := spec.NewUserID(userID, true); err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("user id must be in the form @localpart:domain"), + JSON: spec.BadJSON("user id must be in the form @localpart:domain"), } } } switch r.Preset { - case presetPrivateChat, presetTrustedPrivateChat, presetPublicChat, "": + case spec.PresetPrivateChat, spec.PresetTrustedPrivateChat, spec.PresetPublicChat, "": default: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("preset must be any of 'private_chat', 'trusted_private_chat', 'public_chat'"), + JSON: spec.BadJSON("preset must be any of 'private_chat', 'trusted_private_chat', 'public_chat'"), } } @@ -105,7 +85,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("malformed creation_content"), + JSON: spec.BadJSON("malformed creation_content"), } } @@ -114,7 +94,7 @@ func (r createRoomRequest) Validate() *util.JSONResponse { if err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("malformed creation_content"), + JSON: spec.BadJSON("malformed creation_content"), } } @@ -127,13 +107,6 @@ type createRoomResponse struct { RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec } -// fledglingEvent is a helper representation of an event used when creating many events in succession. -type fledglingEvent struct { - Type string `json:"type"` - StateKey string `json:"state_key"` - Content interface{} `json:"content"` -} - // CreateRoom implements /createRoom func CreateRoom( req *http.Request, device *api.Device, @@ -141,456 +114,124 @@ func CreateRoom( profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, ) util.JSONResponse { - var r createRoomRequest - resErr := httputil.UnmarshalJSONRequest(req, &r) + var createRequest createRoomRequest + resErr := httputil.UnmarshalJSONRequest(req, &createRequest) if resErr != nil { return *resErr } - if resErr = r.Validate(); resErr != nil { + if resErr = createRequest.Validate(); resErr != nil { return *resErr } evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } - return createRoom(req.Context(), r, device, cfg, profileAPI, rsAPI, asAPI, evTime) + return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime) } // createRoom implements /createRoom -// nolint: gocyclo func createRoom( ctx context.Context, - r createRoomRequest, device *api.Device, + createRequest createRoomRequest, device *api.Device, cfg *config.ClientAPI, profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) util.JSONResponse { - _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) + userID, err := spec.NewUserID(device.UserID, true) if err != nil { - util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("invalid userID") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - if !cfg.Matrix.IsLocalServerName(userDomain) { + if !cfg.Matrix.IsLocalServerName(userID.Domain()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), + JSON: spec.Forbidden(fmt.Sprintf("User domain %q not configured locally", userID.Domain())), } } - // TODO (#267): Check room ID doesn't clash with an existing one, and we - // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) - logger := util.GetLogger(ctx) - userID := device.UserID + + // TODO: Check room ID doesn't clash with an existing one, and we + // probably shouldn't be using pseudo-random strings, maybe GUIDs? + roomID, err := spec.NewRoomID(fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain())) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("invalid roomID") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } // Clobber keys: creator, room_version roomVersion := roomserverVersion.DefaultRoomVersion() - if r.RoomVersion != "" { - candidateVersion := gomatrixserverlib.RoomVersion(r.RoomVersion) + if createRequest.RoomVersion != "" { + candidateVersion := gomatrixserverlib.RoomVersion(createRequest.RoomVersion) _, roomVersionError := roomserverVersion.SupportedRoomVersion(candidateVersion) if roomVersionError != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(roomVersionError.Error()), + JSON: spec.UnsupportedRoomVersion(roomVersionError.Error()), } } roomVersion = candidateVersion } - // TODO: visibility/presets/raw initial state - // TODO: Create room alias association - // Make sure this doesn't fall into an application service's namespace though! - logger.WithFields(log.Fields{ - "userID": userID, - "roomID": roomID, + "userID": userID.String(), + "roomID": roomID.String(), "roomVersion": roomVersion, }).Info("Creating new room") - profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) + profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID.String(), asAPI, profileAPI) if err != nil { util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") - return jsonerror.InternalServerError() - } - - createContent := map[string]interface{}{} - if len(r.CreationContent) > 0 { - if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { - util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("invalid create content"), - } - } - } - createContent["creator"] = userID - createContent["room_version"] = roomVersion - powerLevelContent := eventutil.InitialPowerLevelsContent(userID) - joinRuleContent := gomatrixserverlib.JoinRuleContent{ - JoinRule: gomatrixserverlib.Invite, - } - historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{ - HistoryVisibility: historyVisibilityShared, - } - - if r.PowerLevelContentOverride != nil { - // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults - err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("malformed power_level_content_override"), - } - } - } - - switch r.Preset { - case presetPrivateChat: - joinRuleContent.JoinRule = gomatrixserverlib.Invite - historyVisibilityContent.HistoryVisibility = historyVisibilityShared - case presetTrustedPrivateChat: - joinRuleContent.JoinRule = gomatrixserverlib.Invite - historyVisibilityContent.HistoryVisibility = historyVisibilityShared - for _, invitee := range r.Invite { - powerLevelContent.Users[invitee] = 100 - } - case presetPublicChat: - joinRuleContent.JoinRule = gomatrixserverlib.Public - historyVisibilityContent.HistoryVisibility = historyVisibilityShared - } - - createEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomCreate, - Content: createContent, - } - powerLevelEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, - Content: powerLevelContent, - } - joinRuleEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomJoinRules, - Content: joinRuleContent, - } - historyVisibilityEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomHistoryVisibility, - Content: historyVisibilityContent, - } - membershipEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomMember, - StateKey: userID, - Content: gomatrixserverlib.MemberContent{ - Membership: gomatrixserverlib.Join, - DisplayName: profile.DisplayName, - AvatarURL: profile.AvatarURL, - }, - } - - var nameEvent *fledglingEvent - var topicEvent *fledglingEvent - var guestAccessEvent *fledglingEvent - var aliasEvent *fledglingEvent - - if r.Name != "" { - nameEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomName, - Content: eventutil.NameContent{ - Name: r.Name, - }, - } - } - - if r.Topic != "" { - topicEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomTopic, - Content: eventutil.TopicContent{ - Topic: r.Topic, - }, - } - } - - if r.GuestCanJoin { - guestAccessEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomGuestAccess, - Content: eventutil.GuestAccessContent{ - GuestAccess: "can_join", - }, - } - } - - var roomAlias string - if r.RoomAliasName != "" { - roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, userDomain) - // check it's free TODO: This races but is better than nothing - hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ - Alias: roomAlias, - IncludeAppservices: false, - } - - var aliasResp roomserverAPI.GetRoomIDForAliasResponse - err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() - } - if aliasResp.RoomID != "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.RoomInUse("Room ID already exists."), - } - } - - aliasEvent = &fledglingEvent{ - Type: gomatrixserverlib.MRoomCanonicalAlias, - Content: eventutil.CanonicalAlias{ - Alias: roomAlias, - }, - } - } - - var initialStateEvents []fledglingEvent - for i := range r.InitialState { - if r.InitialState[i].StateKey != "" { - initialStateEvents = append(initialStateEvents, r.InitialState[i]) - continue - } - - switch r.InitialState[i].Type { - case gomatrixserverlib.MRoomCreate: - continue - - case gomatrixserverlib.MRoomPowerLevels: - powerLevelEvent = r.InitialState[i] - - case gomatrixserverlib.MRoomJoinRules: - joinRuleEvent = r.InitialState[i] - - case gomatrixserverlib.MRoomHistoryVisibility: - historyVisibilityEvent = r.InitialState[i] - - case gomatrixserverlib.MRoomGuestAccess: - guestAccessEvent = &r.InitialState[i] - - case gomatrixserverlib.MRoomName: - nameEvent = &r.InitialState[i] - - case gomatrixserverlib.MRoomTopic: - topicEvent = &r.InitialState[i] - - default: - initialStateEvents = append(initialStateEvents, r.InitialState[i]) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - // send events into the room in order of: - // 1- m.room.create - // 2- room creator join member - // 3- m.room.power_levels - // 4- m.room.join_rules - // 5- m.room.history_visibility - // 6- m.room.canonical_alias (opt) - // 7- m.room.guest_access (opt) - // 8- other initial state items - // 9- m.room.name (opt) - // 10- m.room.topic (opt) - // 11- invite events (opt) - with is_direct flag if applicable TODO - // 12- 3pid invite events (opt) TODO - // This differs from Synapse slightly. Synapse would vary the ordering of 3-7 - // depending on if those events were in "initial_state" or not. This made it - // harder to reason about, hence sticking to a strict static ordering. - // TODO: Synapse has txn/token ID on each event. Do we need to do this here? - eventsToMake := []fledglingEvent{ - createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent, - } - if guestAccessEvent != nil { - eventsToMake = append(eventsToMake, *guestAccessEvent) - } - eventsToMake = append(eventsToMake, initialStateEvents...) - if nameEvent != nil { - eventsToMake = append(eventsToMake, *nameEvent) - } - if topicEvent != nil { - eventsToMake = append(eventsToMake, *topicEvent) - } - if aliasEvent != nil { - // TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room. - // This means we might fail creating the alias but say the canonical alias is something that doesn't exist. - eventsToMake = append(eventsToMake, *aliasEvent) - } - - // TODO: invite events - // TODO: 3pid invite events + userDisplayName := profile.DisplayName + userAvatarURL := profile.AvatarURL - var builtEvents []*gomatrixserverlib.HeaderedEvent - authEvents := gomatrixserverlib.NewAuthEvents(nil) - for i, e := range eventsToMake { - depth := i + 1 // depth starts at 1 + keyID := cfg.Matrix.KeyID + privateKey := cfg.Matrix.PrivateKey - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, - RoomID: roomID, - Type: e.Type, - StateKey: &e.StateKey, - Depth: int64(depth), - } - err = builder.SetContent(e.Content) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") - return jsonerror.InternalServerError() - } - if i > 0 { - builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} - } - var ev *gomatrixserverlib.Event - ev, err = buildEvent(&builder, userDomain, &authEvents, cfg, evTime, roomVersion) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("buildEvent failed") - return jsonerror.InternalServerError() - } + req := roomserverAPI.PerformCreateRoomRequest{ + InvitedUsers: createRequest.Invite, + RoomName: createRequest.Name, + Visibility: createRequest.Visibility, + Topic: createRequest.Topic, + StatePreset: createRequest.Preset, + CreationContent: createRequest.CreationContent, + InitialState: createRequest.InitialState, + RoomAliasName: createRequest.RoomAliasName, + RoomVersion: roomVersion, + PowerLevelContentOverride: createRequest.PowerLevelContentOverride, + IsDirect: createRequest.IsDirect, - if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { - util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") - return jsonerror.InternalServerError() - } - - // Add the event to the list of auth events - builtEvents = append(builtEvents, ev.Headered(roomVersion)) - err = authEvents.AddEvent(ev) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") - return jsonerror.InternalServerError() - } + UserDisplayName: userDisplayName, + UserAvatarURL: userAvatarURL, + KeyID: keyID, + PrivateKey: privateKey, + EventTime: evTime, } - inputs := make([]roomserverAPI.InputRoomEvent, 0, len(builtEvents)) - for _, event := range builtEvents { - inputs = append(inputs, roomserverAPI.InputRoomEvent{ - Kind: roomserverAPI.KindNew, - Event: event, - Origin: userDomain, - SendAsServer: roomserverAPI.DoNotSendToOtherServers, - }) - } - if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { - util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") - return jsonerror.InternalServerError() - } - - // TODO(#269): Reserve room alias while we create the room. This stops us - // from creating the room but still failing due to the alias having already - // been taken. - if roomAlias != "" { - aliasReq := roomserverAPI.SetRoomAliasRequest{ - Alias: roomAlias, - RoomID: roomID, - UserID: userID, - } - - var aliasResp roomserverAPI.SetRoomAliasResponse - err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return jsonerror.InternalServerError() - } - - if aliasResp.AliasExists { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.RoomInUse("Room alias already exists."), - } - } - } - - // If this is a direct message then we should invite the participants. - if len(r.Invite) > 0 { - // Build some stripped state for the invite. - var globalStrippedState []gomatrixserverlib.InviteV2StrippedState - for _, event := range builtEvents { - // Chosen events from the spec: - // https://spec.matrix.org/v1.3/client-server-api/#stripped-state - switch event.Type() { - case gomatrixserverlib.MRoomCreate: - fallthrough - case gomatrixserverlib.MRoomName: - fallthrough - case gomatrixserverlib.MRoomAvatar: - fallthrough - case gomatrixserverlib.MRoomTopic: - fallthrough - case gomatrixserverlib.MRoomCanonicalAlias: - fallthrough - case gomatrixserverlib.MRoomEncryption: - fallthrough - case gomatrixserverlib.MRoomMember: - fallthrough - case gomatrixserverlib.MRoomJoinRules: - ev := event.Event - globalStrippedState = append( - globalStrippedState, - gomatrixserverlib.NewInviteV2StrippedState(ev), - ) - } - } - - // Process the invites. - for _, invitee := range r.Invite { - // Build the invite event. - inviteEvent, err := buildMembershipEvent( - ctx, invitee, "", profileAPI, device, gomatrixserverlib.Invite, - roomID, r.IsDirect, cfg, evTime, rsAPI, asAPI, - ) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - continue - } - inviteStrippedState := append( - globalStrippedState, - gomatrixserverlib.NewInviteV2StrippedState(inviteEvent.Event), - ) - // Send the invite event to the roomserver. - var inviteRes roomserverAPI.PerformInviteResponse - event := inviteEvent.Headered(roomVersion) - if err := rsAPI.PerformInvite(ctx, &roomserverAPI.PerformInviteRequest{ - Event: event, - InviteRoomState: inviteStrippedState, - RoomVersion: event.RoomVersion, - SendAsServer: string(userDomain), - }, &inviteRes); err != nil { - util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), - } - } - if inviteRes.Error != nil { - return inviteRes.Error.JSONResponse() - } - } - } - - if r.Visibility == "public" { - // expose this room in the published room list - var pubRes roomserverAPI.PerformPublishResponse - if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ - RoomID: roomID, - Visibility: "public", - }, &pubRes); err != nil { - return jsonerror.InternalAPIError(ctx, err) - } - if pubRes.Error != nil { - // treat as non-fatal since the room is already made by this point - util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") - } + roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req) + if createRes != nil { + return *createRes } response := createRoomResponse{ - RoomID: roomID, + RoomID: roomID.String(), RoomAlias: roomAlias, } @@ -599,31 +240,3 @@ func createRoom( JSON: response, } } - -// buildEvent fills out auth_events for the builder then builds the event -func buildEvent( - builder *gomatrixserverlib.EventBuilder, - serverName gomatrixserverlib.ServerName, - provider gomatrixserverlib.AuthEventProvider, - cfg *config.ClientAPI, - evTime time.Time, - roomVersion gomatrixserverlib.RoomVersion, -) (*gomatrixserverlib.Event, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) - if err != nil { - return nil, err - } - refs, err := eventsNeeded.AuthEventReferences(provider) - if err != nil { - return nil, err - } - builder.AuthEvents = refs - event, err := builder.Build( - evTime, serverName, cfg.Matrix.KeyID, - cfg.Matrix.PrivateKey, roomVersion, - ) - if err != nil { - return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %w", builder.Type, err) - } - return event, nil -} diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go index 030589794f..4a824caee1 100644 --- a/clientapi/routing/deactivate.go +++ b/clientapi/routing/deactivate.go @@ -5,9 +5,9 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -24,7 +24,7 @@ func Deactivate( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + JSON: spec.BadJSON("The request body could not be read: " + err.Error()), } } var userId string @@ -41,7 +41,10 @@ func Deactivate( localpart, _, err := gomatrixserverlib.SplitID('@', userId) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var res api.PerformAccountDeactivationResponse @@ -50,7 +53,10 @@ func Deactivate( }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformAccountDeactivation failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index e3a02661c4..f57b8957fb 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -15,15 +15,16 @@ package routing import ( + "encoding/json" "io" "net" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -59,7 +60,10 @@ func GetDeviceByID( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var targetDevice *api.Device for _, device := range queryRes.Devices { @@ -71,7 +75,7 @@ func GetDeviceByID( if targetDevice == nil { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown device"), + JSON: spec.NotFound("Unknown device"), } } @@ -96,7 +100,10 @@ func GetDevicesByLocalpart( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res := devicesJSON{} @@ -138,18 +145,21 @@ func UpdateDeviceByID( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceUpdate failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !performRes.DeviceExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.Forbidden("device does not exist"), + JSON: spec.Forbidden("device does not exist"), } } if performRes.Forbidden { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("device not owned by current user"), + JSON: spec.Forbidden("device not owned by current user"), } } @@ -179,7 +189,7 @@ func DeleteDeviceById( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + JSON: spec.BadJSON("The request body could not be read: " + err.Error()), } } @@ -189,7 +199,7 @@ func DeleteDeviceById( if dev != deviceID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("session & device mismatch"), + JSON: spec.Forbidden("session and device mismatch"), } } } @@ -211,7 +221,10 @@ func DeleteDeviceById( localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // make sure that the access token being used matches the login creds used for user interactive auth, else @@ -219,7 +232,7 @@ func DeleteDeviceById( if login.Username() != localpart && login.Username() != device.UserID { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("Cannot delete another user's device"), + JSON: spec.Forbidden("Cannot delete another user's device"), } } @@ -229,7 +242,10 @@ func DeleteDeviceById( DeviceIDs: []string{deviceID}, }, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } deleteOK = true @@ -242,13 +258,39 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, userAPI api.ClientUserAPI, device *api.Device, + req *http.Request, userInteractiveAuth *auth.UserInteractive, userAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { ctx := req.Context() - payload := devicesDeleteJSON{} - if resErr := httputil.UnmarshalJSONRequest(req, &payload); resErr != nil { - return *resErr + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("The request body could not be read: " + err.Error()), + } + } + defer req.Body.Close() // nolint:errcheck + + // initiate UIA + login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) + if errRes != nil { + return *errRes + } + + if login.Username() != device.UserID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("unable to delete devices for other user"), + } + } + + payload := devicesDeleteJSON{} + if err = json.Unmarshal(bodyBytes, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Error("unable to unmarshal device deletion request") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } defer req.Body.Close() // nolint: errcheck @@ -259,7 +301,10 @@ func DeleteDevices( DeviceIDs: payload.Devices, }, &res); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 696f0c1ef6..d9129d1bd0 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -20,10 +20,10 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -35,7 +35,7 @@ type roomDirectoryResponse struct { Servers []string `json:"servers"` } -func (r *roomDirectoryResponse) fillServers(servers []gomatrixserverlib.ServerName) { +func (r *roomDirectoryResponse) fillServers(servers []spec.ServerName) { r.Servers = make([]string, len(servers)) for i, s := range servers { r.Servers[i] = string(s) @@ -46,7 +46,7 @@ func (r *roomDirectoryResponse) fillServers(servers []gomatrixserverlib.ServerNa func DirectoryRoom( req *http.Request, roomAlias string, - federation *fclient.FederationClient, + federation fclient.FederationClient, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, fedSenderAPI federationAPI.ClientFederationAPI, @@ -55,7 +55,7 @@ func DirectoryRoom( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), + JSON: spec.BadJSON("Room alias must be in the form '#localpart:domain'"), } } @@ -69,7 +69,10 @@ func DirectoryRoom( queryRes := &roomserverAPI.GetRoomIDForAliasResponse{} if err = rsAPI.GetRoomIDForAlias(req.Context(), queryReq, queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.RoomID = queryRes.RoomID @@ -83,7 +86,10 @@ func DirectoryRoom( // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. util.GetLogger(req.Context()).WithError(fedErr).Error("federation.LookupRoomAlias failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.RoomID = fedRes.RoomID res.fillServers(fedRes.Servers) @@ -92,7 +98,7 @@ func DirectoryRoom( if res.RoomID == "" { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound( + JSON: spec.NotFound( fmt.Sprintf("Room alias %s not found", roomAlias), ), } @@ -102,7 +108,10 @@ func DirectoryRoom( var joinedHostsRes federationAPI.QueryJoinedHostServerNamesInRoomResponse if err = fedSenderAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &joinedHostsReq, &joinedHostsRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("fedSenderAPI.QueryJoinedHostServerNamesInRoom failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res.fillServers(joinedHostsRes.ServerNames) } @@ -125,14 +134,14 @@ func SetLocalAlias( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), + JSON: spec.BadJSON("Room alias must be in the form '#localpart:domain'"), } } if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Alias must be on local homeserver"), + JSON: spec.Forbidden("Alias must be on local homeserver"), } } @@ -145,7 +154,7 @@ func SetLocalAlias( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("User ID must be in the form '@localpart:domain'"), + JSON: spec.BadJSON("User ID must be in the form '@localpart:domain'"), } } for _, appservice := range cfg.Derived.ApplicationServices { @@ -157,7 +166,7 @@ func SetLocalAlias( if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive("Alias is reserved by an application service"), + JSON: spec.ASExclusive("Alias is reserved by an application service"), } } } @@ -180,13 +189,16 @@ func SetLocalAlias( var queryRes roomserverAPI.SetRoomAliasResponse if err := rsAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if queryRes.AliasExists { return util.JSONResponse{ Code: http.StatusConflict, - JSON: jsonerror.Unknown("The alias " + alias + " already exists."), + JSON: spec.Unknown("The alias " + alias + " already exists."), } } @@ -203,27 +215,63 @@ func RemoveLocalAlias( alias string, rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "UserID for device is invalid"}, + } + } + + roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias} + roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{} + err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + + validRoomID, err := spec.NewRoomID(roomIDRes.RoomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + queryReq := roomserverAPI.RemoveRoomAliasRequest{ - Alias: alias, - UserID: device.UserID, + Alias: alias, + SenderID: deviceSenderID, } var queryRes roomserverAPI.RemoveRoomAliasResponse if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.RemoveRoomAlias failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !queryRes.Found { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The alias does not exist."), + JSON: spec.NotFound("The alias does not exist."), } } if !queryRes.Removed { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You do not have permission to remove this alias."), + JSON: spec.Forbidden("You do not have permission to remove this alias."), } } @@ -248,12 +296,15 @@ func GetVisibility( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPublishedRooms failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var v roomVisibility if len(res.RoomIDs) == 1 { - v.Visibility = gomatrixserverlib.Public + v.Visibility = spec.Public } else { v.Visibility = "private" } @@ -270,7 +321,30 @@ func SetVisibility( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) + deviceUserID, err := spec.NewUserID(dev.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("userID for this device is invalid"), + } + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("failed to find senderID for this user"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -278,23 +352,26 @@ func SetVisibility( queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{ RoomID: roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }}, } var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse - err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) + err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) if err != nil || len(queryEventsRes.StateEvents) == 0 { util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event - power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].Event) - if power.UserLevel(dev.UserID) < power.EventLevel(gomatrixserverlib.MRoomCanonicalAlias, true) { + power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU) + if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID doesn't have power level to change visibility"), + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), } } @@ -303,16 +380,15 @@ func SetVisibility( return *reqErr } - var publishRes roomserverAPI.PerformPublishResponse - if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + if err = rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: v.Visibility, - }, &publishRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if publishRes.Error != nil { - util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") - return publishRes.Error.JSONResponse() + }); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to publish room") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -328,7 +404,7 @@ func SetVisibilityAS( if dev.AccountType != userapi.AccountTypeAppService { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Only appservice may use this endpoint"), + JSON: spec.Forbidden("Only appservice may use this endpoint"), } } var v roomVisibility @@ -341,18 +417,17 @@ func SetVisibilityAS( return *reqErr } } - var publishRes roomserverAPI.PerformPublishResponse if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: v.Visibility, NetworkID: networkID, AppserviceID: dev.AppserviceID, - }, &publishRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if publishRes.Error != nil { - util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") - return publishRes.Error.JSONResponse() + }); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to publish room") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 8e1e05a532..67146630cc 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -23,13 +23,12 @@ import ( "strings" "sync" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" ) @@ -57,7 +56,7 @@ type filter struct { func GetPostPublicRooms( req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - federation *fclient.FederationClient, + federation fclient.FederationClient, cfg *config.ClientAPI, ) util.JSONResponse { var request PublicRoomReq @@ -68,11 +67,11 @@ func GetPostPublicRooms( if request.IncludeAllNetworks && request.NetworkID != "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("include_all_networks and third_party_instance_id can not be used together"), + JSON: spec.InvalidParam("include_all_networks and third_party_instance_id can not be used together"), } } - serverName := gomatrixserverlib.ServerName(request.Server) + serverName := spec.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( req.Context(), cfg.Matrix.ServerName, serverName, @@ -82,7 +81,10 @@ func GetPostPublicRooms( ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to get public rooms") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -93,7 +95,10 @@ func GetPostPublicRooms( response, err := publicRooms(req.Context(), request, rsAPI, extRoomsProvider) if err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to work out public rooms") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -173,7 +178,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO if httpReq.Method != "GET" && httpReq.Method != "POST" { return &util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + JSON: spec.NotFound("Bad method"), } } if httpReq.Method == "GET" { @@ -184,7 +189,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") return &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("limit param is not a number"), + JSON: spec.BadJSON("limit param is not a number"), } } request.Limit = int64(limit) diff --git a/clientapi/routing/joined_rooms.go b/clientapi/routing/joined_rooms.go index 4bb353ea99..f664183f87 100644 --- a/clientapi/routing/joined_rooms.go +++ b/clientapi/routing/joined_rooms.go @@ -19,9 +19,9 @@ import ( "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type getJoinedRoomsResponse struct { @@ -40,7 +40,10 @@ func GetJoinedRooms( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.RoomIDs == nil { res.RoomIDs = []string{} diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 3493dd6d88..43331b42ae 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -15,15 +15,17 @@ package routing import ( + "encoding/json" "net/http" "time" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -41,7 +43,6 @@ func JoinRoomByIDOrAlias( IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } - joinRes := roomserverAPI.PerformJoinResponse{} // Check to see if any ?server_name= query parameters were // given in the request. @@ -49,7 +50,7 @@ func JoinRoomByIDOrAlias( for _, serverName := range serverNames { joinReq.ServerNames = append( joinReq.ServerNames, - gomatrixserverlib.ServerName(serverName), + spec.ServerName(serverName), ) } } @@ -72,7 +73,7 @@ func JoinRoomByIDOrAlias( util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Unable to query user profile, no profile found."), + JSON: spec.Unknown("Unable to query user profile, no profile found."), } default: } @@ -81,37 +82,65 @@ func JoinRoomByIDOrAlias( done := make(chan util.JSONResponse, 1) go func() { defer close(done) - if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { - done <- jsonerror.InternalAPIError(req.Context(), err) - } else if joinRes.Error != nil { - if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { - done <- util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), - } - } else { - done <- joinRes.Error.JSONResponse() - } - } else { - done <- util.JSONResponse{ + roomID, _, err := rsAPI.PerformJoin(req.Context(), &joinReq) + var response util.JSONResponse + + switch e := err.(type) { + case nil: // success case + response = util.JSONResponse{ Code: http.StatusOK, // TODO: Put the response struct somewhere internal. JSON: struct { RoomID string `json:"room_id"` - }{joinRes.RoomID}, + }{roomID}, + } + case roomserverAPI.ErrInvalidID: + response = util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(e.Error()), + } + case roomserverAPI.ErrNotAllowed: + jsonErr := spec.Forbidden(e.Error()) + if device.AccountType == api.AccountTypeGuest { + jsonErr = spec.GuestAccessForbidden(e.Error()) + } + response = util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonErr, + } + case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client + response = util.JSONResponse{ + Code: e.Code, + JSON: json.RawMessage(e.Message), + } + case eventutil.ErrRoomNoExists: + response = util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(e.Error()), + } + default: + response = util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } + done <- response }() // Wait either for the join to finish, or for us to hit a reasonable // timeout, at which point we'll just return a 200 to placate clients. + timer := time.NewTimer(time.Second * 20) select { - case <-time.After(time.Second * 20): + case <-timer.C: return util.JSONResponse{ Code: http.StatusAccepted, - JSON: jsonerror.Unknown("The room join will continue in the background."), + JSON: spec.Unknown("The room join will continue in the background."), } case result := <-done: + // Stop and drain the timer + if !timer.Stop() { + <-timer.C + } return result } } diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index fd58ff5d53..0ddff8a95b 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/roomserver" @@ -63,10 +64,9 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { IsDirect: true, Topic: "testing", Visibility: "public", - Preset: presetPublicChat, + Preset: spec.PresetPublicChat, RoomAliasName: "alias", Invite: []string{bob.ID}, - GuestCanJoin: false, }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) crResp, ok := resp.JSON.(createRoomResponse) if !ok { @@ -75,13 +75,12 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { // create a room with guest access enabled and invite Charlie resp = createRoom(ctx, createRoomRequest{ - Name: "testing", - IsDirect: true, - Topic: "testing", - Visibility: "public", - Preset: presetPublicChat, - Invite: []string{charlie.ID}, - GuestCanJoin: true, + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: spec.PresetPublicChat, + Invite: []string{charlie.ID}, }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) if !ok { diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index b6f8fe1b9f..7f8bd9f403 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -20,8 +20,8 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -61,28 +61,26 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de if resErr != nil { return *resErr } - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + if len(kb.AuthData) == 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("missing auth_data"), + } + } + version, err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: "", AuthData: kb.AuthData, Algorithm: kb.Algorithm, - }, &performKeyBackupResp); err != nil { - return jsonerror.InternalServerError() - } - if performKeyBackupResp.Error != "" { - if performKeyBackupResp.BadInput { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), - } - } - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + }) + if err != nil { + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", err)) } + return util.JSONResponse{ Code: 200, JSON: keyBackupVersionCreateResponse{ - Version: performKeyBackupResp.Version, + Version: version, }, } } @@ -90,20 +88,17 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de // KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { - var queryResp userapi.QueryKeyBackupResponse - if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, - }, &queryResp); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if queryResp.Error != "" { - return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + }) + if err != nil { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", err)) } if !queryResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("version not found"), + JSON: spec.NotFound("version not found"), } } return util.JSONResponse{ @@ -126,31 +121,29 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse if resErr != nil { return *resErr } - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, AuthData: kb.AuthData, Algorithm: kb.Algorithm, - }, &performKeyBackupResp); err != nil { - return jsonerror.InternalServerError() - } - if performKeyBackupResp.Error != "" { - if performKeyBackupResp.BadInput { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), - } + }) + switch e := err.(type) { + case spec.ErrRoomKeysVersion: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: e, } - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + case nil: + default: + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e)) } + if !performKeyBackupResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("backup version not found"), + JSON: spec.NotFound("backup version not found"), } } - // Unclear what the 200 body should be return util.JSONResponse{ Code: 200, JSON: keyBackupVersionCreateResponse{ @@ -162,35 +155,19 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.ClientUse // Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. // Implements DELETE /_matrix/client/r0/room_keys/version/{version} func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ - UserID: device.UserID, - Version: version, - DeleteBackup: true, - }, &performKeyBackupResp); err != nil { - return jsonerror.InternalServerError() + exists, err := userAPI.DeleteKeyBackup(req.Context(), device.UserID, version) + if err != nil { + return util.ErrorResponse(fmt.Errorf("DeleteKeyBackup: %s", err)) } - if performKeyBackupResp.Error != "" { - if performKeyBackupResp.BadInput { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), - } - } - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) - } - if !performKeyBackupResp.Exists { + if !exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("backup version not found"), + JSON: spec.NotFound("backup version not found"), } } - // Unclear what the 200 body should be return util.JSONResponse{ Code: 200, - JSON: keyBackupVersionCreateResponse{ - Version: performKeyBackupResp.Version, - }, + JSON: struct{}{}, } } @@ -198,27 +175,26 @@ func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de func UploadBackupKeys( req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, ) util.JSONResponse { - var performKeyBackupResp userapi.PerformKeyBackupResponse - if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + performKeyBackupResp, err := userAPI.UpdateBackupKeyAuthData(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, Keys: *keys, - }, &performKeyBackupResp); err != nil && performKeyBackupResp.Error == "" { - return jsonerror.InternalServerError() - } - if performKeyBackupResp.Error != "" { - if performKeyBackupResp.BadInput { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), - } + }) + + switch e := err.(type) { + case spec.ErrRoomKeysVersion: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: e, } - return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + case nil: + default: + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %w", e)) } if !performKeyBackupResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("backup version not found"), + JSON: spec.NotFound("backup version not found"), } } return util.JSONResponse{ @@ -234,23 +210,20 @@ func UploadBackupKeys( func GetBackupKeys( req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, ) util.JSONResponse { - var queryResp userapi.QueryKeyBackupResponse - if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + queryResp, err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, ReturnKeys: true, KeysForRoomID: roomID, KeysForSessionID: sessionID, - }, &queryResp); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if queryResp.Error != "" { - return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + }) + if err != nil { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %w", err)) } if !queryResp.Exists { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("version not found"), + JSON: spec.NotFound("version not found"), } } if sessionID != "" { @@ -267,17 +240,20 @@ func GetBackupKeys( } } else if roomID != "" { roomData, ok := queryResp.Keys[roomID] - if ok { - // wrap response in "sessions" - return util.JSONResponse{ - Code: 200, - JSON: struct { - Sessions map[string]userapi.KeyBackupSession `json:"sessions"` - }{ - Sessions: roomData, - }, - } + if !ok { + // If no keys are found, then an object with an empty sessions property will be returned + roomData = make(map[string]userapi.KeyBackupSession) + } + // wrap response in "sessions" + return util.JSONResponse{ + Code: 200, + JSON: struct { + Sessions map[string]userapi.KeyBackupSession `json:"sessions"` + }{ + Sessions: roomData, + }, } + } else { // response is the same as the upload request var resp keyBackupSessionRequest @@ -298,6 +274,6 @@ func GetBackupKeys( } return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("keys not found"), + JSON: spec.NotFound("keys not found"), } } diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index dc8a92f1c7..a6c7958c49 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -20,9 +20,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -72,31 +72,29 @@ func UploadCrossSigningDeviceKeys( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID - if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) if err := uploadRes.Error; err != nil { switch { case err.IsInvalidSignature: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidSignature(err.Error()), + JSON: spec.InvalidSignature(err.Error()), } case err.IsMissingParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingParam(err.Error()), + JSON: spec.MissingParam(err.Error()), } case err.IsInvalidParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(err.Error()), + JSON: spec.InvalidParam(err.Error()), } default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } } @@ -116,31 +114,29 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie } uploadReq.UserID = device.UserID - if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) if err := uploadRes.Error; err != nil { switch { case err.IsInvalidSignature: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidSignature(err.Error()), + JSON: spec.InvalidSignature(err.Error()), } case err.IsMissingParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingParam(err.Error()), + JSON: spec.MissingParam(err.Error()), } case err.IsInvalidParam: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(err.Error()), + JSON: spec.InvalidParam(err.Error()), } default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } } diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 3d60fcc3a6..72785cda8b 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -22,8 +22,8 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type uploadKeysRequest struct { @@ -67,7 +67,10 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) } if uploadRes.Error != nil { util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(uploadRes.KeyErrors) > 0 { util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys") @@ -112,14 +115,12 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) u return *resErr } queryRes := api.QueryKeysResponse{} - if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ + keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ UserID: device.UserID, UserToDevices: r.DeviceKeys, Timeout: r.GetTimeout(), // TODO: Token? - }, &queryRes); err != nil { - return util.ErrorResponse(err) - } + }, &queryRes) return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ @@ -152,15 +153,16 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse { return *resErr } claimRes := api.PerformClaimKeysResponse{} - if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ + keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: r.OneTimeKeys, Timeout: r.GetTimeout(), - }, &claimRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + }, &claimRes) if claimRes.Error != nil { util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: 200, diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index a716618517..7e8c066eb2 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -17,9 +17,9 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -29,10 +29,18 @@ func LeaveRoomByID( rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("device userID is invalid"), + } + } + // Prepare to ask the roomserver to perform the room join. leaveReq := roomserverAPI.PerformLeaveRequest{ RoomID: roomID, - UserID: device.UserID, + Leaver: *userID, } leaveRes := roomserverAPI.PerformLeaveResponse{} @@ -41,12 +49,12 @@ func LeaveRoomByID( if leaveRes.Code != 0 { return util.JSONResponse{ Code: leaveRes.Code, - JSON: jsonerror.LeaveServerNoticeError(), + JSON: spec.LeaveServerNoticeError(), } } return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 7d0a63e55c..0129150980 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,20 +19,19 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) type loginResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token"` + HomeServer spec.ServerName `json:"home_server"` + DeviceID string `json:"device_id"` } type flows struct { @@ -87,7 +86,7 @@ func Login( } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + JSON: spec.NotFound("Bad method"), } } @@ -98,13 +97,19 @@ func completeAuth( token, err := auth.GenerateAccessToken() if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.GenerateAccessToken failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg) if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var performRes userapi.PerformDeviceCreationResponse @@ -120,7 +125,7 @@ func completeAuth( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + JSON: spec.Unknown("failed to create device: " + err.Error()), } } diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 73bae7af73..d06bac7845 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -17,8 +17,8 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -33,7 +33,10 @@ func Logout( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -53,7 +56,10 @@ func LogoutAll( }, &performRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 1a96d4b1dd..60b120b9ce 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -16,22 +16,25 @@ package routing import ( "context" + "crypto/ed25519" "fmt" "net/http" "time" - "github.com/matrix-org/gomatrixserverlib" - + "github.com/getsentry/sentry-go" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -49,11 +52,33 @@ func SendBan( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), + } + } + + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), + } + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"), } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -62,15 +87,15 @@ func SendBan( if errRes != nil { return *errRes } - allowedToBan := pl.UserLevel(device.UserID) >= pl.Ban + allowedToBan := pl.UserLevel(senderID) >= pl.Ban if !allowedToBan { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to ban this user, power level too low."), + JSON: spec.Forbidden("You don't have permission to ban this user, power level too low."), } } - return sendMembership(req.Context(), profileAPI, device, roomID, gomatrixserverlib.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) + return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) } func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device, @@ -83,14 +108,17 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } serverName := device.UserDomain() if err = roomserverAPI.SendEvents( ctx, rsAPI, roomserverAPI.KindNew, - []*gomatrixserverlib.HeaderedEvent{event}, + []*types.HeaderedEvent{event}, device.UserDomain(), serverName, serverName, @@ -98,7 +126,10 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -119,11 +150,33 @@ func SendKick( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), + } + } + + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -132,31 +185,38 @@ func SendKick( if errRes != nil { return *errRes } - allowedToKick := pl.UserLevel(device.UserID) >= pl.Kick + allowedToKick := pl.UserLevel(senderID) >= pl.Kick if !allowedToKick { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to kick this user, power level too low."), + JSON: spec.Forbidden("You don't have permission to kick this user, power level too low."), } } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: body.UserID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) } // kick is only valid if the user is not currently banned or left (that is, they are joined or invited) - if queryRes.Membership != gomatrixserverlib.Join && queryRes.Membership != gomatrixserverlib.Invite { + if queryRes.Membership != spec.Join && queryRes.Membership != spec.Invite { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Unknown("cannot /kick banned or left users"), + JSON: spec.Unknown("cannot /kick banned or left users"), } } // TODO: should we be using SendLeave instead? - return sendMembership(req.Context(), profileAPI, device, roomID, gomatrixserverlib.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) + return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) } func SendUnban( @@ -171,33 +231,48 @@ func SendUnban( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), + } + } + + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } + bodyUserID, err := spec.NewUserID(body.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("body userID is invalid"), + } + } var queryRes roomserverAPI.QueryMembershipForUserResponse - err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: body.UserID, + UserID: *bodyUserID, }, &queryRes) if err != nil { return util.ErrorResponse(err) } // unban is only valid if the user is currently banned - if queryRes.Membership != gomatrixserverlib.Ban { + if queryRes.Membership != spec.Ban { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("can only /unban users that are banned"), + JSON: spec.Unknown("can only /unban users that are banned"), } } // TODO: should we be using SendLeave instead? - return sendMembership(req.Context(), profileAPI, device, roomID, gomatrixserverlib.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) + return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) } func SendInvite( @@ -230,11 +305,19 @@ func SendInvite( if body.UserID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing user_id"), + JSON: spec.BadJSON("missing user_id"), } } - errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + + errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if errRes != nil { return *errRes } @@ -255,30 +338,44 @@ func sendInvite( asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) (util.JSONResponse, error) { event, err := buildMembershipEvent( - ctx, userID, reason, profileAPI, device, gomatrixserverlib.Invite, + ctx, userID, reason, profileAPI, device, spec.Invite, roomID, false, cfg, evTime, rsAPI, asAPI, ) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } - var inviteRes api.PerformInviteResponse - if err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ + err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ Event: event, InviteRoomState: nil, // ask the roomserver to draw up invite room state for us - RoomVersion: event.RoomVersion, + RoomVersion: event.Version(), SendAsServer: string(device.UserDomain()), - }, &inviteRes); err != nil { + }) + + switch e := err.(type) { + case roomserverAPI.ErrInvalidID: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(e.Error()), + }, e + case roomserverAPI.ErrNotAllowed: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(e.Error()), + }, e + case nil: + default: util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") + sentry.CaptureException(err) return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, }, err } - if inviteRes.Error != nil { - return inviteRes.Error.JSONResponse(), inviteRes.Error - } return util.JSONResponse{ Code: http.StatusOK, @@ -286,6 +383,42 @@ func sendInvite( }, nil } +func buildMembershipEventDirect( + ctx context.Context, + targetSenderID spec.SenderID, reason string, userDisplayName, userAvatarURL string, + sender spec.SenderID, senderDomain spec.ServerName, + membership, roomID string, isDirect bool, + keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time, + rsAPI roomserverAPI.ClientRoomserverAPI, +) (*types.HeaderedEvent, error) { + targetSenderString := string(targetSenderID) + proto := gomatrixserverlib.ProtoEvent{ + SenderID: string(sender), + RoomID: roomID, + Type: "m.room.member", + StateKey: &targetSenderString, + } + + content := gomatrixserverlib.MemberContent{ + Membership: membership, + DisplayName: userDisplayName, + AvatarURL: userAvatarURL, + Reason: reason, + IsDirect: isDirect, + } + + if err := proto.SetContent(content); err != nil { + return nil, err + } + + identity := &fclient.SigningIdentity{ + ServerName: senderDomain, + KeyID: keyID, + PrivateKey: privateKey, + } + return eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, rsAPI, nil) +} + func buildMembershipEvent( ctx context.Context, targetUserID, reason string, profileAPI userapi.ClientUserAPI, @@ -293,37 +426,41 @@ func buildMembershipEvent( membership, roomID string, isDirect bool, cfg *config.ClientAPI, evTime time.Time, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, -) (*gomatrixserverlib.HeaderedEvent, error) { +) (*types.HeaderedEvent, error) { profile, err := loadProfile(ctx, targetUserID, cfg, profileAPI, asAPI) if err != nil { return nil, err } - builder := gomatrixserverlib.EventBuilder{ - Sender: device.UserID, - RoomID: roomID, - Type: "m.room.member", - StateKey: &targetUserID, + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return nil, err } - - content := gomatrixserverlib.MemberContent{ - Membership: membership, - DisplayName: profile.DisplayName, - AvatarURL: profile.AvatarURL, - Reason: reason, - IsDirect: isDirect, + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) + if err != nil { + return nil, err } - if err = builder.SetContent(content); err != nil { + targetID, err := spec.NewUserID(targetUserID, true) + if err != nil { + return nil, err + } + targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID) + if err != nil { return nil, err } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *userID) if err != nil { return nil, err } - return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) + return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL, + senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI) } // loadProfile lookups the profile of a given user from the database and returns @@ -363,7 +500,7 @@ func extractRequestData(req *http.Request) (body *threepid.MembershipRequest, ev if err != nil { resErr = &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } return } @@ -385,36 +522,43 @@ func checkAndProcessThreepid( req.Context(), device, body, cfg, rsAPI, profileAPI, roomID, evTime, ) - if err == threepid.ErrMissingParameter { + switch e := err.(type) { + case nil: + case threepid.ErrMissingParameter: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } - } else if err == threepid.ErrNotTrusted { + case threepid.ErrNotTrusted: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.IDServer), + JSON: spec.NotTrusted(body.IDServer), } - } else if err == eventutil.ErrRoomNoExists { + case eventutil.ErrRoomNoExists: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(err.Error()), + JSON: spec.NotFound(err.Error()), } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + case gomatrixserverlib.BadJSONError: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") return inviteStored, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), } - } - if err != nil { + default: util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAndProcessInvite failed") - er := jsonerror.InternalServerError() - return inviteStored, &er + return inviteStored, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return } -func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse { var membershipRes roomserverAPI.QueryMembershipForUserResponse err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, @@ -422,13 +566,15 @@ func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserver }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryMembershipForUser: could not query membership for user") - e := jsonerror.InternalServerError() - return &e + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.IsInRoom { return &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user does not belong to room"), + JSON: spec.Forbidden("user does not belong to room"), } } return nil @@ -440,26 +586,38 @@ func SendForget( ) util.JSONResponse { ctx := req.Context() logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) + + deviceUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), + } + } + var membershipRes roomserverAPI.QueryMembershipForUserResponse membershipReq := roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *deviceUserID, } - err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) if err != nil { logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } if membershipRes.IsInRoom { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(fmt.Sprintf("User %s is in room %s", device.UserID, roomID)), + JSON: spec.Unknown(fmt.Sprintf("User %s is in room %s", device.UserID, roomID)), } } @@ -470,7 +628,10 @@ func SendForget( response := roomserverAPI.PerformForgetResponse{} if err := rsAPI.PerformForget(ctx, &request, &response); err != nil { logger.WithError(err).Error("PerformForget: unable to forget room") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -480,20 +641,20 @@ func SendForget( func getPowerlevels(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, roomID string) (*gomatrixserverlib.PowerLevelContent, *util.JSONResponse) { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) if plEvent == nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to perform this action, no power_levels event in this room."), + JSON: spec.Forbidden("You don't have permission to perform this action, no power_levels event in this room."), } } pl, err := plEvent.PowerLevels() if err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You don't have permission to perform this action, the power_levels event for this room is malformed so auth checks cannot be performed."), + JSON: spec.Forbidden("You don't have permission to perform this action, the power_levels event for this room is malformed so auth checks cannot be performed."), } } return pl, nil diff --git a/clientapi/routing/multiroom.go b/clientapi/routing/multiroom.go index 14d3c29b69..4d8d54a112 100644 --- a/clientapi/routing/multiroom.go +++ b/clientapi/routing/multiroom.go @@ -4,10 +4,10 @@ import ( "io" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -23,14 +23,14 @@ func PostMultiroom( log.WithError(err).Errorf("failed to read request body") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } canonicalB, err := gomatrixserverlib.CanonicalJSON(b) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body is not valid canonical JSON." + err.Error()), + JSON: spec.BadJSON("The request body is not valid canonical JSON." + err.Error()), } } err = producer.SendMultiroom(req.Context(), device.UserID, dataType, canonicalB) @@ -38,7 +38,7 @@ func PostMultiroom( log.WithError(err).Errorf("failed to send multiroomcast") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } return util.JSONResponse{ diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index f593e27db3..4b9043faae 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -18,9 +18,9 @@ import ( "net/http" "strconv" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -35,7 +35,10 @@ func GetNotifications( limit, err = strconv.ParseInt(limitStr, 10, 64) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -43,7 +46,10 @@ func GetNotifications( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ Localpart: localpart, @@ -54,7 +60,10 @@ func GetNotifications( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications)) return util.JSONResponse{ diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index 8e9be78890..8dfba8af98 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -17,9 +17,9 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -43,7 +43,7 @@ func CreateOpenIDToken( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot request tokens for other users"), + JSON: spec.Forbidden("Cannot request tokens for other users"), } } @@ -55,7 +55,10 @@ func CreateOpenIDToken( err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index b2293c399e..bd0f176d00 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -7,12 +7,12 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -62,7 +62,7 @@ func Password( sessionID = util.RandomString(sessionIDLength) } var localpart string - var domain gomatrixserverlib.ServerName + var domain spec.ServerName switch r.Auth.Type { case authtypes.LoginTypePassword: // Check if the existing password is correct. @@ -78,7 +78,10 @@ func Password( localpart, domain, err = gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) case authtypes.LoginTypeEmail: @@ -91,12 +94,15 @@ func Password( bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg, nil) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !bound { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ + JSON: spec.MatrixError{ ErrCode: "M_THREEPID_AUTH_FAILED", Err: "Failed to auth 3pid", }, @@ -109,12 +115,15 @@ func Password( }, &res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryLocalpartForThreePID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if res.Localpart == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ + JSON: spec.MatrixError{ ErrCode: "M_THREEPID_NOT_FOUND", Err: "3pid is not bound to any account", }, @@ -161,11 +170,17 @@ func Password( passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !passwordRes.PasswordUpdated { util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the request asks us to log out all other devices then @@ -191,7 +206,10 @@ func Password( logoutRes := &api.PerformDeviceDeletionResponse{} if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } pushersReq := &api.PerformPusherDeletionRequest{ @@ -201,7 +219,10 @@ func Password( } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 9b2592eb58..772dc8477f 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -15,13 +15,15 @@ package routing import ( + "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) func PeekRoomByIDOrAlias( @@ -41,25 +43,42 @@ func PeekRoomByIDOrAlias( UserID: device.UserID, DeviceID: device.ID, } - peekRes := roomserverAPI.PerformPeekResponse{} - // Check to see if any ?server_name= query parameters were // given in the request. if serverNames, ok := req.URL.Query()["server_name"]; ok { for _, serverName := range serverNames { peekReq.ServerNames = append( peekReq.ServerNames, - gomatrixserverlib.ServerName(serverName), + spec.ServerName(serverName), ) } } // Ask the roomserver to perform the peek. - if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil { - return util.ErrorResponse(err) - } - if peekRes.Error != nil { - return peekRes.Error.JSONResponse() + roomID, err := rsAPI.PerformPeek(req.Context(), &peekReq) + switch e := err.(type) { + case roomserverAPI.ErrInvalidID: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(e.Error()), + } + case roomserverAPI.ErrNotAllowed: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(e.Error()), + } + case *gomatrix.HTTPError: + return util.JSONResponse{ + Code: e.Code, + JSON: json.RawMessage(e.Message), + } + case nil: + default: + logrus.WithError(err).WithField("roomID", roomIDOrAlias).Errorf("Failed to peek room") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // if this user is already joined to the room, we let them peek anyway @@ -75,7 +94,7 @@ func PeekRoomByIDOrAlias( // TODO: Put the response struct somewhere internal. JSON: struct { RoomID string `json:"room_id"` - }{peekRes.RoomID}, + }{roomID}, } } @@ -85,18 +104,20 @@ func UnpeekRoomByID( rsAPI roomserverAPI.ClientRoomserverAPI, roomID string, ) util.JSONResponse { - unpeekReq := roomserverAPI.PerformUnpeekRequest{ - RoomID: roomID, - UserID: device.UserID, - DeviceID: device.ID, - } - unpeekRes := roomserverAPI.PerformUnpeekResponse{} - - if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } - if unpeekRes.Error != nil { - return unpeekRes.Error.JSONResponse() + err := rsAPI.PerformUnpeek(req.Context(), roomID, device.UserID, device.ID) + switch e := err.(type) { + case roomserverAPI.ErrInvalidID: + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(e.Error()), + } + case nil: + default: + logrus.WithError(err).WithField("roomID", roomID).Errorf("Failed to un-peek room") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/presence.go b/clientapi/routing/presence.go index 093a62464c..5aa6d8dd29 100644 --- a/clientapi/routing/presence.go +++ b/clientapi/routing/presence.go @@ -21,13 +21,12 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -54,7 +53,7 @@ func SetPresence( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Unable to set presence for other user."), + JSON: spec.Forbidden("Unable to set presence for other user."), } } var presence presenceReq @@ -67,7 +66,7 @@ func SetPresence( if !ok { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(fmt.Sprintf("Unknown presence '%s'.", presence.Presence)), + JSON: spec.Unknown(fmt.Sprintf("Unknown presence '%s'.", presence.Presence)), } } err := producer.SendPresence(req.Context(), userID, presenceStatus, presence.StatusMsg) @@ -75,7 +74,7 @@ func SetPresence( log.WithError(err).Errorf("failed to update presence") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -100,7 +99,7 @@ func GetPresence( log.WithError(err).Errorf("unable to get presence") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -119,11 +118,11 @@ func GetPresence( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } - p := types.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + p := types.PresenceInternal{LastActiveTS: spec.Timestamp(lastActive)} currentlyActive := p.CurrentlyActive() return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 2589c88979..35da15e0e5 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -21,16 +21,16 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" ) @@ -40,19 +40,22 @@ func GetProfile( req *http.Request, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *fclient.FederationClient, + federation fclient.FederationClient, ) util.JSONResponse { profile, err := getProfile(req.Context(), profileAPI, cfg, userID, asAPI, federation) if err != nil { if err == appserviceAPI.ErrProfileNotExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), + JSON: spec.NotFound("The user does not exist or does not have a profile"), } } util.GetLogger(req.Context()).WithError(err).Error("getProfile failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -68,7 +71,7 @@ func GetProfile( func GetAvatarURL( req *http.Request, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *fclient.FederationClient, + federation fclient.FederationClient, ) util.JSONResponse { profile := GetProfile(req, profileAPI, cfg, userID, asAPI, federation) p, ok := profile.JSON.(eventutil.UserProfile) @@ -93,7 +96,7 @@ func SetAvatarURL( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -105,13 +108,16 @@ func SetAvatarURL( localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + JSON: spec.Forbidden("userID does not belong to a locally configured domain"), } } @@ -119,14 +125,17 @@ func SetAvatarURL( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } profile, changed, err := profileAPI.SetAvatarURL(req.Context(), localpart, domain, r.AvatarURL) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // No need to build new membership events, since nothing changed if !changed { @@ -136,7 +145,7 @@ func SetAvatarURL( } } - response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, cfg, evTime) + response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, evTime) if err != nil { return response } @@ -151,7 +160,7 @@ func SetAvatarURL( func GetDisplayName( req *http.Request, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *fclient.FederationClient, + federation fclient.FederationClient, ) util.JSONResponse { profile := GetProfile(req, profileAPI, cfg, userID, asAPI, federation) p, ok := profile.JSON.(eventutil.UserProfile) @@ -176,7 +185,7 @@ func SetDisplayName( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not match the current user"), + JSON: spec.Forbidden("userID does not match the current user"), } } @@ -184,23 +193,20 @@ func SetDisplayName( if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr } - if r.DisplayName == "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("'displayname' must be supplied."), - } - } localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + JSON: spec.Forbidden("userID does not belong to a locally configured domain"), } } @@ -208,14 +214,17 @@ func SetDisplayName( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } profile, changed, err := profileAPI.SetDisplayName(req.Context(), localpart, domain, r.DisplayName) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // No need to build new membership events, since nothing changed if !changed { @@ -225,7 +234,7 @@ func SetDisplayName( } } - response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, cfg, evTime) + response, err := updateProfile(req.Context(), rsAPI, device, profile, userID, evTime) if err != nil { return response } @@ -239,7 +248,7 @@ func SetDisplayName( func updateProfile( ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device, profile *authtypes.Profile, - userID string, cfg *config.ClientAPI, evTime time.Time, + userID string, evTime time.Time, ) (util.JSONResponse, error) { var res api.QueryRoomsForUserResponse err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ @@ -248,33 +257,45 @@ func updateProfile( }, &res) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } events, err := buildMembershipEvents( - ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, + ctx, res.RoomIDs, *profile, userID, evTime, rsAPI, ) switch e := err.(type) { case nil: case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(e.Error()), }, e default: util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError(), e + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, e } if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError(), err + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + }, err } return util.JSONResponse{}, nil } @@ -287,7 +308,7 @@ func getProfile( ctx context.Context, profileAPI userapi.ProfileAPI, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceInternalAPI, - federation *fclient.FederationClient, + federation fclient.FederationClient, ) (*authtypes.Profile, error) { localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -323,49 +344,60 @@ func getProfile( func buildMembershipEvents( ctx context.Context, - device *userapi.Device, roomIDs []string, - newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, + newProfile authtypes.Profile, userID string, evTime time.Time, rsAPI api.ClientRoomserverAPI, -) ([]*gomatrixserverlib.HeaderedEvent, error) { - evs := []*gomatrixserverlib.HeaderedEvent{} +) ([]*types.HeaderedEvent, error) { + evs := []*types.HeaderedEvent{} + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } for _, roomID := range roomIDs { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { return nil, err } - - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) + if err != nil { + return nil, err + } + senderIDString := string(senderID) + proto := gomatrixserverlib.ProtoEvent{ + SenderID: senderIDString, RoomID: roomID, Type: "m.room.member", - StateKey: &userID, + StateKey: &senderIDString, } content := gomatrixserverlib.MemberContent{ - Membership: gomatrixserverlib.Join, + Membership: spec.Join, } content.DisplayName = newProfile.DisplayName content.AvatarURL = newProfile.AvatarURL - if err := builder.SetContent(content); err != nil { + if err = proto.SetContent(content); err != nil { + return nil, err + } + + user, err := spec.NewUserID(userID, true) + if err != nil { return nil, err } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *user) if err != nil { return nil, err } - event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, nil) if err != nil { return nil, err } - evs = append(evs, event.Headered(verRes.RoomVersion)) + evs = append(evs, event) } return evs, nil diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index 548423c3ce..a74b4cad68 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -19,9 +19,9 @@ import ( "net/url" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -34,7 +34,10 @@ func GetPushers( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ Localpart: localpart, @@ -42,7 +45,10 @@ func GetPushers( }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } for i := range queryRes.Pushers { queryRes.Pushers[i].SessionID = 0 @@ -63,7 +69,10 @@ func SetPusher( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } body := userapi.PerformPusherSetRequest{} if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil { @@ -99,7 +108,10 @@ func SetPusher( err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -111,6 +123,6 @@ func SetPusher( func invalidParam(msg string) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(msg), + JSON: spec.InvalidParam(msg), } } diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go index 856f52c755..74873d5c91 100644 --- a/clientapi/routing/pushrules.go +++ b/clientapi/routing/pushrules.go @@ -7,31 +7,34 @@ import ( "net/http" "reflect" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/pushrules" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse { - if eerr, ok := err.(*jsonerror.MatrixError); ok { + if eerr, ok := err.(spec.MatrixError); ok { var status int switch eerr.ErrCode { - case "M_INVALID_ARGUMENT_VALUE": + case spec.ErrorInvalidParam: status = http.StatusBadRequest - case "M_NOT_FOUND": + case spec.ErrorNotFound: status = http.StatusNotFound default: status = http.StatusInternalServerError } - return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err) + return util.MatrixErrorResponse(status, string(eerr.ErrCode), eerr.Err) } util.GetLogger(ctx).WithError(err).Errorf(msg, args...) - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRulesJSON failed") } @@ -42,13 +45,13 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap } func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRulesJSON failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -57,17 +60,18 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi } func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) - if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + // Even if rulesPtr is not nil, there may not be any rules for this kind + if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) { + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -76,21 +80,21 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi } func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -101,26 +105,30 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { var newRule pushrules.Rule if err := json.NewDecoder(body).Decode(&newRule); err != nil { - return errorResponse(ctx, err, "JSON Decode failed") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(err.Error()), + } } newRule.RuleID = ruleID errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule) if len(errs) > 0 { - return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs) + return errorResponse(ctx, spec.InvalidParam(errs[0].Error()), "rule sanity check failed: %v", errs) } - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + // while this should be impossible (ValidateRule would already return an error), better keep it around + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i >= 0 && afterRuleID == "" && beforeRuleID == "" { @@ -144,7 +152,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, } // Add new rule. - i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) + i, err = findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID) if err != nil { return errorResponse(ctx, err, "findPushRuleInsertionIndex failed") } @@ -153,7 +161,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i) } - if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil { return errorResponse(ctx, err, "putPushRules failed") } @@ -161,26 +169,26 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, } func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...) - if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil { return errorResponse(ctx, err, "putPushRules failed") } @@ -192,21 +200,21 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri if err != nil { return errorResponse(ctx, err, "pushRuleAttrGetter failed") } - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } return util.JSONResponse{ Code: http.StatusOK, @@ -221,7 +229,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } } if newPartialRule.Actions == nil { @@ -238,27 +246,27 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri return errorResponse(ctx, err, "pushRuleAttrSetter failed") } - ruleSets, err := queryPushRules(ctx, device.UserID, userAPI) + ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID) if err != nil { return errorResponse(ctx, err, "queryPushRules failed") } ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope)) if ruleSet == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rule set"), "pushRuleSetByScope failed") } rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind)) if rulesPtr == nil { - return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed") + return errorResponse(ctx, spec.InvalidParam("invalid push rules kind"), "pushRuleSetKindPointer failed") } i := pushRuleIndexByID(*rulesPtr, ruleID) if i < 0 { - return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed") + return errorResponse(ctx, spec.NotFound("push rule ID not found"), "pushRuleIndexByID failed") } if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) { attrSet((*rulesPtr)[i], &newPartialRule) - if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil { + if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil { return errorResponse(ctx, err, "putPushRules failed") } } @@ -266,28 +274,6 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} } -func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) { - var res userapi.QueryPushRulesResponse - if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil { - util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed") - return nil, err - } - return res.RuleSets, nil -} - -func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error { - req := userapi.PerformPushRulesPutRequest{ - UserID: userID, - RuleSets: ruleSets, - } - var res struct{} - if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil { - util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed") - return err - } - return nil -} - func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet { switch scope { case pushrules.GlobalScope: @@ -330,7 +316,7 @@ func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) case "enabled": return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil default: - return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + return nil, spec.InvalidParam("invalid push rule attribute") } } @@ -341,7 +327,7 @@ func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) { case "enabled": return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil default: - return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute") + return nil, spec.InvalidParam("invalid push rule attribute") } } @@ -355,10 +341,10 @@ func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID strin } } if i == len(rules) { - return 0, jsonerror.NotFound("after: rule ID not found") + return 0, spec.NotFound("after: rule ID not found") } if rules[i].Default { - return 0, jsonerror.NotFound("after: rule ID must not be a default rule") + return 0, spec.NotFound("after: rule ID must not be a default rule") } // We stopped on the "after" match to differentiate // not-found from is-last-entry. Now we move to the earliest @@ -373,10 +359,10 @@ func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID strin } } if i == len(rules) { - return 0, jsonerror.NotFound("before: rule ID not found") + return 0, spec.NotFound("before: rule ID not found") } if rules[i].Default { - return 0, jsonerror.NotFound("before: rule ID must not be a default rule") + return 0, spec.NotFound("before: rule ID must not be a default rule") } } diff --git a/clientapi/routing/receipt.go b/clientapi/routing/receipt.go index 99217a7802..be6542979f 100644 --- a/clientapi/routing/receipt.go +++ b/clientapi/routing/receipt.go @@ -20,9 +20,8 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -31,7 +30,7 @@ import ( ) func SetReceipt(req *http.Request, userAPI api.ClientUserAPI, syncProducer *producers.SyncAPIProducer, device *userapi.Device, roomID, receiptType, eventID string) util.JSONResponse { - timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + timestamp := spec.AsTimestamp(time.Now()) logrus.WithFields(logrus.Fields{ "roomID": roomID, "receiptType": receiptType, @@ -49,7 +48,10 @@ func SetReceipt(req *http.Request, userAPI api.ClientUserAPI, syncProducer *prod case "m.fully_read": data, err := json.Marshal(fullyReadEvent{EventID: eventID}) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } dataReq := api.InputAccountDataRequest{ diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index f86bbc8fd9..1b9a5a8188 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -16,17 +16,19 @@ package routing import ( "context" + "errors" "net/http" "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" ) @@ -45,7 +47,29 @@ func SendRedaction( txnID *string, txnCache *transactions.Cache, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) + deviceUserID, userIDErr := spec.NewUserID(device.UserID, true) + if userIDErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) + if queryErr != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to redact"), + } + } + + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -61,46 +85,46 @@ func SendRedaction( if ev == nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.NotFound("unknown event ID"), // TODO: is it ok to leak existence? + JSON: spec.NotFound("unknown event ID"), // TODO: is it ok to leak existence? } } if ev.RoomID() != roomID { return util.JSONResponse{ Code: 400, - JSON: jsonerror.NotFound("cannot redact event in another room"), + JSON: spec.NotFound("cannot redact event in another room"), } } // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.Sender() == device.UserID + allowedToRedact := ev.SenderID() == senderID if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) if plEvent == nil { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), + JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), } } - pl, err := plEvent.PowerLevels() - if err != nil { + pl, plErr := plEvent.PowerLevels() + if plErr != nil { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( "You don't have permission to redact this event, the power_levels event for this room is malformed so auth checks cannot be performed.", ), } } - allowedToRedact = pl.UserLevel(device.UserID) >= pl.Redact + allowedToRedact = pl.UserLevel(senderID) >= pl.Redact } if !allowedToRedact { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("You don't have permission to redact this event, power level too low."), + JSON: spec.Forbidden("You don't have permission to redact this event, power level too low."), } } @@ -111,35 +135,44 @@ func SendRedaction( } // create the new event and set all the fields we can - builder := gomatrixserverlib.EventBuilder{ - Sender: device.UserID, - RoomID: roomID, - Type: gomatrixserverlib.MRoomRedaction, - Redacts: eventID, + proto := gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), + RoomID: roomID, + Type: spec.MRoomRedaction, + Redacts: eventID, } - err := builder.SetContent(r) + err = proto.SetContent(r) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") - return jsonerror.InternalServerError() + util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + identity, err := rsAPI.SigningIdentityFor(req.Context(), *validRoomID, *deviceUserID) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + e, err := eventutil.QueryAndBuildEvent(req.Context(), &proto, &identity, time.Now(), rsAPI, &queryRes) + if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } domain := device.UserDomain() - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*types.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res := util.JSONResponse{ diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 4102a5a7df..2921c9cb63 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -37,6 +37,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -45,7 +46,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -165,7 +165,7 @@ func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtype return } } - d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) + d.sessions[sessionID] = append(d.sessions[sessionID], stage) } func (d *sessionsDict) addDeviceToDelete(sessionID, deviceID string) { @@ -207,10 +207,10 @@ var ( // previous parameters with the ones supplied. This mean you cannot "build up" request params. type registerRequest struct { // registration parameters - Password string `json:"password"` - Username string `json:"username"` - ServerName gomatrixserverlib.ServerName `json:"-"` - Admin bool `json:"admin"` + Password string `json:"password"` + Username string `json:"username"` + ServerName spec.ServerName `json:"-"` + Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` @@ -429,7 +429,7 @@ func validateApplicationService( if matchedApplicationService == nil { return "", &util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.UnknownToken("Supplied access_token does not match any known application service"), + JSON: spec.UnknownToken("Supplied access_token does not match any known application service"), } } @@ -440,7 +440,7 @@ func validateApplicationService( // If we didn't find any matches, return M_EXCLUSIVE return "", &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive(fmt.Sprintf( + JSON: spec.ASExclusive(fmt.Sprintf( "Supplied username %s did not match any namespaces for application service ID: %s", username, matchedApplicationService.ID)), } } @@ -449,7 +449,7 @@ func validateApplicationService( if UsernameMatchesMultipleExclusiveNamespaces(cfg, userID) { return "", &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive(fmt.Sprintf( + JSON: spec.ASExclusive(fmt.Sprintf( "Supplied username %s matches multiple exclusive application service namespaces. Only 1 match allowed", username)), } } @@ -475,12 +475,12 @@ func Register( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("Unable to read request body"), + JSON: spec.NotJSON("Unable to read request body"), } } var r registerRequest - host := gomatrixserverlib.ServerName(req.Host) + host := spec.ServerName(req.Host) if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { r.ServerName = v.ServerName } else { @@ -519,7 +519,7 @@ func Register( if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + JSON: spec.InvalidUsername("Numeric user IDs are reserved"), } } // Auto generate a numeric username if r.Username is empty @@ -530,7 +530,10 @@ func Register( nres := &userapi.QueryNumericLocalpartResponse{} if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } r.Username = strconv.FormatInt(nres.ID, 10) } @@ -553,7 +556,7 @@ func Register( // type is not known or specified) return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("A known registration type (e.g. m.login.application_service) must be specified if an access_token is provided"), + JSON: spec.MissingParam("A known registration type (e.g. m.login.application_service) must be specified if an access_token is provided"), } default: // Spec-compliant case (neither the access_token nor the login type are @@ -591,7 +594,7 @@ func handleGuestRegistration( if !registrationEnabled || !guestsEnabled { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( fmt.Sprintf("Guest registration is disabled on %q", r.ServerName), ), } @@ -605,7 +608,7 @@ func handleGuestRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + JSON: spec.Unknown("failed to create account: " + err.Error()), } } token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ @@ -617,7 +620,7 @@ func handleGuestRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Failed to generate access token"), + JSON: spec.Unknown("Failed to generate access token"), } } //we don't allow guests to specify their own device_id @@ -633,7 +636,7 @@ func handleGuestRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + JSON: spec.Unknown("failed to create device: " + err.Error()), } } return util.JSONResponse{ @@ -683,7 +686,7 @@ func handleRegistrationFlow( if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( fmt.Sprintf("Registration is disabled on %q", r.ServerName), ), } @@ -697,7 +700,7 @@ func handleRegistrationFlow( UsernameMatchesExclusiveNamespaces(cfg, r.Username) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive("This username is reserved by an application service."), + JSON: spec.ASExclusive("This username is reserved by an application service."), } } @@ -708,15 +711,15 @@ func handleRegistrationFlow( err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) switch err { case ErrCaptchaDisabled: - return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())} case ErrMissingResponse: - return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + return util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.BadJSON(err.Error())} case ErrInvalidCaptcha: - return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: spec.BadJSON(err.Error())} case nil: default: util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") - return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}} } // Add Recaptcha to the list of completed registration stages @@ -737,12 +740,15 @@ func handleRegistrationFlow( bound, threePid.Address, threePid.Medium, err = threepid.CheckAssociation(req.Context(), r.Auth.ThreePidCreds, cfg, nil) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !bound { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ + JSON: spec.MatrixError{ ErrCode: "M_THREEPID_AUTH_FAILED", Err: "Failed to auth 3pid", }, @@ -757,7 +763,7 @@ func handleRegistrationFlow( default: return util.JSONResponse{ Code: http.StatusNotImplemented, - JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + JSON: spec.Unknown("unknown/unimplemented auth type"), } } @@ -789,7 +795,7 @@ func handleApplicationServiceRegistration( if tokenErr != nil { return util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingToken(tokenErr.Error()), + JSON: spec.MissingToken(tokenErr.Error()), } } @@ -849,7 +855,7 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username string, serverName gomatrixserverlib.ServerName, displayName string, + username string, serverName spec.ServerName, displayName string, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, deviceDisplayName, deviceID *string, @@ -859,14 +865,14 @@ func completeRegistration( if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Missing username"), + JSON: spec.MissingParam("Missing username"), } } // Blank passwords are only allowed by registered application services if password == "" && appserviceID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Missing password"), + JSON: spec.MissingParam("Missing password"), } } var accRes userapi.PerformAccountCreationResponse @@ -882,12 +888,12 @@ func completeRegistration( if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), + JSON: spec.UserInUse("Desired user ID is already taken."), } } return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + JSON: spec.Unknown("failed to create account: " + err.Error()), } } @@ -905,7 +911,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Failed to save 3PID association: " + err.Error()), + JSON: spec.Unknown("Failed to save 3PID association: " + err.Error()), } } } @@ -925,7 +931,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("Failed to generate access token"), + JSON: spec.Unknown("Failed to generate access token"), } } @@ -934,7 +940,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to set display name: " + err.Error()), + JSON: spec.Unknown("failed to set display name: " + err.Error()), } } } @@ -952,7 +958,7 @@ func completeRegistration( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to create device: " + err.Error()), + JSON: spec.Unknown("failed to create device: " + err.Error()), } } @@ -1036,7 +1042,7 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) domain := cfg.Matrix.ServerName - host := gomatrixserverlib.ServerName(req.Host) + host := spec.ServerName(req.Host) if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { domain = v.ServerName } @@ -1047,7 +1053,7 @@ func RegisterAvailable( if v.ServerName == domain && !v.AllowRegistration { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden( + JSON: spec.Forbidden( fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)), ), } @@ -1064,7 +1070,7 @@ func RegisterAvailable( if appservice.OwnsNamespaceCoveringUserId(userID) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is reserved by an application service."), + JSON: spec.UserInUse("Desired user ID is reserved by an application service."), } } } @@ -1077,14 +1083,14 @@ func RegisterAvailable( if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("failed to check availability:" + err.Error()), + JSON: spec.Unknown("failed to check availability:" + err.Error()), } } if !res.Available { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired User ID is already taken."), + JSON: spec.UserInUse("Desired User ID is already taken."), } } @@ -1101,7 +1107,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("malformed json: %s", err)), + JSON: spec.BadJSON(fmt.Sprintf("malformed json: %s", err)), } } valid, err := sr.IsValidMacLogin(ssrr.Nonce, ssrr.User, ssrr.Password, ssrr.Admin, ssrr.MacBytes) @@ -1111,7 +1117,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if !valid { return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("bad mac"), + JSON: spec.Forbidden("bad mac"), } } // downcase capitals diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 50e32283e0..5b7855ad6c 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -28,7 +28,6 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -39,6 +38,7 @@ import ( "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" @@ -306,7 +306,7 @@ func Test_register(t *testing.T) { guestsDisabled: true, wantResponse: util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`), + JSON: spec.Forbidden(`Guest registration is disabled on "test"`), }, }, { @@ -318,7 +318,7 @@ func Test_register(t *testing.T) { loginType: "im.not.known", wantResponse: util.JSONResponse{ Code: http.StatusNotImplemented, - JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + JSON: spec.Unknown("unknown/unimplemented auth type"), }, }, { @@ -326,7 +326,7 @@ func Test_register(t *testing.T) { registrationDisabled: true, wantResponse: util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(`Registration is disabled on "test"`), + JSON: spec.Forbidden(`Registration is disabled on "test"`), }, }, { @@ -344,7 +344,7 @@ func Test_register(t *testing.T) { username: "success", wantResponse: util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), + JSON: spec.UserInUse("Desired user ID is already taken."), }, }, { @@ -361,7 +361,7 @@ func Test_register(t *testing.T) { username: "1337", wantResponse: util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + JSON: spec.InvalidUsername("Numeric user IDs are reserved"), }, }, { @@ -369,7 +369,7 @@ func Test_register(t *testing.T) { loginType: authtypes.LoginTypeRecaptcha, wantResponse: util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()), + JSON: spec.Unknown(ErrCaptchaDisabled.Error()), }, }, { @@ -378,7 +378,7 @@ func Test_register(t *testing.T) { loginType: authtypes.LoginTypeRecaptcha, wantResponse: util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(ErrMissingResponse.Error()), + JSON: spec.BadJSON(ErrMissingResponse.Error()), }, }, { @@ -388,7 +388,7 @@ func Test_register(t *testing.T) { captchaBody: `notvalid`, wantResponse: util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()), + JSON: spec.BadJSON(ErrInvalidCaptcha.Error()), }, }, { @@ -402,7 +402,7 @@ func Test_register(t *testing.T) { enableRecaptcha: true, loginType: authtypes.LoginTypeRecaptcha, captchaBody: `i should fail for other reasons`, - wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}}, }, } @@ -484,7 +484,7 @@ func Test_register(t *testing.T) { if !reflect.DeepEqual(r.Flows, cfg.Derived.Registration.Flows) { t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, cfg.Derived.Registration.Flows) } - case *jsonerror.MatrixError: + case spec.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) } @@ -541,7 +541,12 @@ func Test_register(t *testing.T) { resp = Register(req, userAPI, &cfg.ClientAPI) switch resp.JSON.(type) { - case *jsonerror.MatrixError: + case spec.InternalServerError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + case spec.MatrixError: if !reflect.DeepEqual(tc.wantResponse, resp) { t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) } diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 92b9e66553..5a5296bf4f 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -19,10 +19,10 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -39,14 +39,17 @@ func GetTags( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot retrieve another user's tags"), + JSON: spec.Forbidden("Cannot retrieve another user's tags"), } } tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -71,7 +74,7 @@ func PutTag( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot modify another user's tags"), + JSON: spec.Forbidden("Cannot modify another user's tags"), } } @@ -83,7 +86,10 @@ func PutTag( tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if tagContent.Tags == nil { @@ -93,7 +99,10 @@ func PutTag( if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -118,14 +127,17 @@ func DeleteTag( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot modify another user's tags"), + JSON: spec.Forbidden("Cannot modify another user's tags"), } } tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Check whether the tag to be deleted exists @@ -141,7 +153,10 @@ func DeleteTag( if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d42ff81e7c..f34eec17f5 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,20 +20,21 @@ import ( "strings" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/setup/base" - userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" + + "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" clientutil "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/ratelimit" federationAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -57,7 +58,7 @@ func Setup( asAPI appserviceAPI.AppServiceInternalAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - federation *fclient.FederationClient, + federation fclient.FederationClient, syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, @@ -87,6 +88,14 @@ func Setup( unstableFeatures["org.matrix."+msc] = true } + // singleflight protects /join endpoints from being invoked + // multiple times from the same user and room, otherwise + // a state reset can occur. This also avoids unneeded + // state calculations. + // TODO: actually fix this in the roomserver, as there are + // possibly other ways that can result in a stat reset. + sf := singleflight.Group{} + if cfg.Matrix.WellKnownClientName != "" { logrus.Infof("Setting m.homeserver base_url as %s at /.well-known/matrix/client", cfg.Matrix.WellKnownClientName) wkMux.Handle("/client", httputil.MakeExternalAPI("wellknown", func(r *http.Request) util.JSONResponse { @@ -150,11 +159,41 @@ func Setup( } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("unknown method"), + JSON: spec.NotFound("unknown method"), } }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + dendriteAdminRouter.Handle("/admin/registrationTokens/new", + httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminCreateNewRegistrationToken(req, cfg, userAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + dendriteAdminRouter.Handle("/admin/registrationTokens", + httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminListRegistrationTokens(req, cfg, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", + httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + switch req.Method { + case http.MethodGet: + return AdminGetRegistrationToken(req, cfg, userAPI) + case http.MethodPut: + return AdminUpdateRegistrationToken(req, cfg, userAPI) + case http.MethodDelete: + return AdminDeleteRegistrationToken(req, cfg, userAPI) + default: + return util.MatrixErrorResponse( + 404, + string(spec.ErrorNotFound), + "unknown method", + ) + } + }), + ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -164,13 +203,13 @@ func Setup( dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminEvacuateUser(req, cfg, rsAPI) + return AdminEvacuateUser(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminPurgeRoom(req, cfg, device, rsAPI) + return AdminPurgeRoom(req, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -182,7 +221,7 @@ func Setup( dendriteAdminRouter.Handle("/admin/downloadState/{serverName}/{roomID}", httputil.MakeAdminAPI("admin_download_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminDownloadState(req, cfg, device, rsAPI) + return AdminDownloadState(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -259,7 +298,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/join/{roomIDOrAlias}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -267,15 +306,23 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return JoinRoomByIDOrAlias( - req, device, rsAPI, userAPI, vars["roomIDOrAlias"], - ) + // Only execute a join for roomIDOrAlias and UserID once. If there is a join in progress + // it waits for it to complete and returns that result for subsequent requests. + resp, _, _ := sf.Do(vars["roomIDOrAlias"]+device.UserID, func() (any, error) { + return JoinRoomByIDOrAlias( + req, device, rsAPI, userAPI, vars["roomIDOrAlias"], + ), nil + }) + // once all joins are processed, drop them from the cache. Further requests + // will be processed as usual. + sf.Forget(vars["roomIDOrAlias"] + device.UserID) + return resp.(util.JSONResponse) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { v3mux.Handle("/peek/{roomIDOrAlias}", - httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -295,7 +342,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } @@ -303,9 +350,17 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return JoinRoomByIDOrAlias( - req, device, rsAPI, userAPI, vars["roomID"], - ) + // Only execute a join for roomID and UserID once. If there is a join in progress + // it waits for it to complete and returns that result for subsequent requests. + resp, _, _ := sf.Do(vars["roomID"]+device.UserID, func() (any, error) { + return JoinRoomByIDOrAlias( + req, device, rsAPI, userAPI, vars["roomID"], + ), nil + }) + // once all joins are processed, drop them from the cache. Further requests + // will be processed as usual. + sf.Forget(vars["roomID"] + device.UserID) + return resp.(util.JSONResponse) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", @@ -672,7 +727,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("missing trailing slash"), + JSON: spec.InvalidParam("missing trailing slash"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -687,7 +742,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"), + JSON: spec.InvalidParam("scope, kind and rule ID must be specified"), } }), ).Methods(http.MethodPut) @@ -706,7 +761,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"), + JSON: spec.InvalidParam("missing trailing slash after scope"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -715,7 +770,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"), + JSON: spec.InvalidParam("kind and rule ID must be specified"), } }), ).Methods(http.MethodPut) @@ -734,7 +789,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"), + JSON: spec.InvalidParam("missing trailing slash after kind"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -743,7 +798,7 @@ func Setup( httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"), + JSON: spec.InvalidParam("rule ID must be specified"), } }), ).Methods(http.MethodPut) @@ -952,7 +1007,7 @@ func Setup( // TODO: Allow people to peek into rooms. return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.GuestAccessForbidden("Guest access not implemented"), + JSON: spec.GuestAccessForbidden("Guest access not implemented"), } }), ).Methods(http.MethodGet, http.MethodOptions) @@ -1120,7 +1175,7 @@ func Setup( v3mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return DeleteDevices(req, userAPI, device) + return DeleteDevices(req, userInteractiveAuth, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -1257,7 +1312,7 @@ func Setup( if version == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), + JSON: spec.InvalidParam("version must be specified"), } } var reqBody keyBackupSessionRequest @@ -1278,7 +1333,7 @@ func Setup( if version == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), + JSON: spec.InvalidParam("version must be specified"), } } roomID := vars["roomID"] @@ -1310,7 +1365,7 @@ func Setup( if version == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("version must be specified"), + JSON: spec.InvalidParam("version must be specified"), } } var reqBody userapi.KeyBackupSession @@ -1411,7 +1466,7 @@ func Setup( }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { return *r } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 90af9ac4d1..41a3793ae2 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -23,18 +23,18 @@ import ( "sync" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" ) // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid @@ -67,6 +67,8 @@ var sendEventDuration = prometheus.NewHistogramVec( // /rooms/{roomID}/send/{eventType} // /rooms/{roomID}/send/{eventType}/{txnID} // /rooms/{roomID}/state/{eventType}/{stateKey} +// +// nolint: gocyclo func SendEvent( req *http.Request, device *userapi.Device, @@ -75,12 +77,11 @@ func SendEvent( rsAPI api.ClientRoomserverAPI, txnCache *transactions.Cache, ) util.JSONResponse { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { + roomVersion, err := rsAPI.QueryRoomVersionForRoom(req.Context(), roomID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } @@ -117,26 +118,37 @@ func SendEvent( // If we're sending a membership update, make sure to strip the authorised // via key if it is present, otherwise other servers won't be able to auth // the event if the room is set to the "restricted" join rule. - if eventType == gomatrixserverlib.MRoomMember { + if eventType == spec.MRoomMember { delete(r, "join_authorised_via_users_server") } + // for power level events we need to replace the userID with the pseudoID + if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { + err = updatePowerLevels(req, r, roomID, rsAPI) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: err.Error()}, + } + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } - e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime) + e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime) if resErr != nil { return *resErr } timeToGenerateEvent := time.Since(startedGeneratingEvent) // validate that the aliases exists - if eventType == gomatrixserverlib.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" { + if eventType == spec.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" { aliasReq := api.AliasEvent{} if err = json.Unmarshal(e.Content(), &aliasReq); err != nil { return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err)) @@ -144,12 +156,15 @@ func SendEvent( if !aliasReq.Valid() { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("Request contains invalid aliases."), + JSON: spec.InvalidParam("Request contains invalid aliases."), } } aliasRes := &api.GetAliasesForRoomIDResponse{} if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var found int requestAliases := append(aliasReq.AltAliases, aliasReq.Alias) @@ -164,7 +179,7 @@ func SendEvent( if aliasReq.Alias != "" && found < len(requestAliases) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadAlias("No matching alias found."), + JSON: spec.BadAlias("No matching alias found."), } } } @@ -183,8 +198,8 @@ func SendEvent( if err := api.SendEvents( req.Context(), rsAPI, api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - e.Headered(verRes.RoomVersion), + []*types.HeaderedEvent{ + &types.HeaderedEvent{PDU: e}, }, device.UserDomain(), domain, @@ -193,13 +208,16 @@ func SendEvent( false, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } timeToSubmitEvent := time.Since(startedSubmittingEvent) util.GetLogger(req.Context()).WithFields(logrus.Fields{ "event_id": e.EventID(), "room_id": roomID, - "room_version": verRes.RoomVersion, + "room_version": roomVersion, }).Info("Sent event to roomserver") res := util.JSONResponse{ @@ -219,6 +237,28 @@ func SendEvent( return res } +func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error { + userMap := r["users"].(map[string]interface{}) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + for user, level := range userMap { + uID, err := spec.NewUserID(user, true) + if err != nil { + continue // we're modifying the map in place, so we're going to have invalid userIDs after the first iteration + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *uID) + if err != nil { + return err + } + userMap[string(senderID)] = level + delete(userMap, user) + } + r["users"] = userMap + return nil +} + // stateEqual compares the new and the existing state event content. If they are equal, returns a *util.JSONResponse // with the existing event_id, making this an idempotent request. func stateEqual(ctx context.Context, rsAPI api.ClientRoomserverAPI, eventType, stateKey, roomID string, newContent map[string]interface{}) *util.JSONResponse { @@ -255,72 +295,101 @@ func generateSendEvent( r map[string]interface{}, device *userapi.Device, roomID, eventType string, stateKey *string, - cfg *config.ClientAPI, rsAPI api.ClientRoomserverAPI, evTime time.Time, -) (*gomatrixserverlib.Event, *util.JSONResponse) { +) (gomatrixserverlib.PDU, *util.JSONResponse) { // parse the incoming http request - userID := device.UserID + fullUserID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Bad userID"), + } + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("Unable to find senderID for user"), + } + } // create the new event and set all the fields we can - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, + proto := gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), RoomID: roomID, Type: eventType, StateKey: stateKey, } - err := builder.SetContent(r) + err = proto.SetContent(r) if err != nil { - util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") - resErr := jsonerror.InternalServerError() - return nil, &resErr + util.GetLogger(ctx).WithError(err).Error("proto.SetContent failed") + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID) if err != nil { - resErr := jsonerror.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes) + switch specificErr := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: return nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + case gomatrixserverlib.BadJSONError: return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(specificErr.Error()), } - } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { - if e.Code == gomatrixserverlib.EventValidationTooLarge { + case gomatrixserverlib.EventValidationError: + if specificErr.Code == gomatrixserverlib.EventValidationTooLarge { return nil, &util.JSONResponse{ Code: http.StatusRequestEntityTooLarge, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(specificErr.Error()), } } return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: spec.BadJSON(specificErr.Error()), } - } else if err != nil { + default: util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed") - resErr := jsonerror.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // check to see if this user can perform this operation - stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) + stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) for i := range queryRes.StateEvents { - stateEvents[i] = queryRes.StateEvents[i].Event + stateEvents[i] = queryRes.StateEvents[i].PDU } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(e.Event, &provider); err != nil { + provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID) + }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? + JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? } } @@ -331,16 +400,16 @@ func generateSendEvent( util.GetLogger(ctx).WithError(err).Error("Cannot unmarshal the event content.") return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Cannot unmarshal the event content."), + JSON: spec.BadJSON("Cannot unmarshal the event content."), } } if content["replacement_room"] == e.RoomID() { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam("Cannot send tombstone event that points to the same room."), + JSON: spec.InvalidParam("Cannot send tombstone event that points to the same room."), } } } - return e.Event, nil + return e.PDU, nil } diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go index 0dd1d26217..7b5499a62f 100644 --- a/clientapi/routing/sendtodevice.go +++ b/clientapi/routing/sendtodevice.go @@ -20,10 +20,10 @@ import ( "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/internal/transactions" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} @@ -54,7 +54,10 @@ func SendToDevice( req.Context(), device.UserID, userID, deviceID, eventType, message, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.WithFields(logrus.Fields{ "to_device_id": deviceID, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 9dc884d627..979bced3b2 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -18,10 +18,10 @@ import ( "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type typingContentJSON struct { @@ -39,12 +39,20 @@ func SendTyping( if device.UserID != userID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot set another user's typing state"), + JSON: spec.Forbidden("Cannot set another user's typing state"), + } + } + + deviceUserID, err := spec.NewUserID(userID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), } } // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID) if resErr != nil { return *resErr } @@ -58,7 +66,10 @@ func SendTyping( if err := syncProducer.SendTyping(req.Context(), userID, roomID, r.Typing, r.Timeout); err != nil { util.GetLogger(req.Context()).WithError(err).Error("eduProducer.Send failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index d6191f3b42..66258a68ac 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -22,22 +22,22 @@ import ( "time" "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/transactions" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Unspecced server notice request @@ -52,6 +52,7 @@ type sendServerNoticeRequest struct { StateKey string `json:"state_key,omitempty"` } +// nolint:gocyclo // SendServerNotice sends a message to a specific user. It can only be invoked by an admin. func SendServerNotice( req *http.Request, @@ -68,7 +69,7 @@ func SendServerNotice( if device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("This API can only be used by admin users."), + JSON: spec.Forbidden("This API can only be used by admin users."), } } @@ -90,7 +91,7 @@ func SendServerNotice( if !r.valid() { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid request"), + JSON: spec.BadJSON("Invalid request"), } } @@ -155,9 +156,8 @@ func SendServerNotice( Invite: []string{r.UserID}, Name: cfgNotices.RoomName, Visibility: "private", - Preset: presetPrivateChat, + Preset: spec.PresetPrivateChat, CreationContent: cc, - GuestCanJoin: false, RoomVersion: roomVersion, PowerLevelContentOverride: pl, } @@ -176,7 +176,10 @@ func SendServerNotice( }} if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { util.GetLogger(ctx).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } default: @@ -185,12 +188,23 @@ func SendServerNotice( } } else { // we've found a room in common, check the membership + deviceUserID, err := spec.NewUserID(r.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("userID doesn't have power level to change visibility"), + } + } + roomID = commonRooms[0] membershipRes := api.QueryMembershipForUserResponse{} - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.IsInRoom { // re-invite the user @@ -207,7 +221,7 @@ func SendServerNotice( "body": r.Content.Body, "msgtype": r.Content.MsgType, } - e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) + e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now()) if resErr != nil { logrus.Errorf("failed to send message: %+v", resErr) return *resErr @@ -228,8 +242,8 @@ func SendServerNotice( if err := api.SendEvents( ctx, rsAPI, api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - e.Headered(roomVersion), + []*types.HeaderedEvent{ + {PDU: e}, }, device.UserDomain(), cfgClient.Matrix.ServerName, @@ -238,7 +252,10 @@ func SendServerNotice( false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": e.EventID(), @@ -333,7 +350,7 @@ func getSenderDevice( if len(deviceRes.Devices) > 0 { // If there were changes to the profile, create a new membership event if displayNameChanged || avatarChanged { - _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now()) + _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, time.Now()) if err != nil { return nil, err } diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index d24d8c9fb9..099dfc0105 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -20,11 +20,12 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -55,12 +56,15 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a StateToFetch: []gomatrixserverlib.StateKeyTuple{}, }, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !stateRes.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } @@ -68,11 +72,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // that marks the room as world-readable. If we don't then we assume that // the room is not world-readable. for _, ev := range stateRes.StateEvents { - if ev.Type() == gomatrixserverlib.MRoomHistoryVisibility { + if ev.Type() == spec.MRoomHistoryVisibility { content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -92,20 +99,31 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a if !worldReadable { // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. - if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite { + if !membershipRes.HasBeenInRoom && membershipRes.Membership != spec.Invite { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), + JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), } } // Otherwise, if the user has been in the room, whether or not we @@ -132,7 +150,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - synctypes.HeaderedToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, ev), ) } } else { @@ -146,12 +166,34 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a }, &stateAfterRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } for _, ev := range stateAfterRes.StateEvents { + sender := spec.UserID{} + evRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid") + continue + } + userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := ev.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } stateEvents = append( stateEvents, - synctypes.HeaderedToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk), ) } } @@ -185,9 +227,9 @@ func OnIncomingStateTypeRequest( StateKey: stateKey, }, } - if evType != gomatrixserverlib.MRoomHistoryVisibility && stateKey != "" { + if evType != spec.MRoomHistoryVisibility && stateKey != "" { stateToFetch = append(stateToFetch, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", }) } @@ -201,18 +243,24 @@ func OnIncomingStateTypeRequest( StateToFetch: stateToFetch, }, &stateRes); err != nil { util.GetLogger(ctx).WithError(err).Error("queryAPI.QueryLatestEventsAndState failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Look at the room state and see if we have a history visibility event // that marks the room as world-readable. If we don't then we assume that // the room is not world-readable. for _, ev := range stateRes.StateEvents { - if ev.Type() == gomatrixserverlib.MRoomHistoryVisibility { + if ev.Type() == spec.MRoomHistoryVisibility { content := map[string]string{} if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if visibility, ok := content["history_visibility"]; ok { worldReadable = visibility == "world_readable" @@ -230,22 +278,33 @@ func OnIncomingStateTypeRequest( // membershipRes will only be populated if the room is not world-readable. var membershipRes api.QueryMembershipForUserResponse if !worldReadable { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UserID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Device UserID is invalid"), + } + } // The room isn't world-readable so try to work out based on the // user's membership if we want the latest state or not. - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, }, &membershipRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // If the user has never been in the room then stop at this point. // We won't tell the user about a room they have never joined. - if !membershipRes.HasBeenInRoom && membershipRes.Membership != gomatrixserverlib.Invite || membershipRes.Membership == gomatrixserverlib.Ban { + if !membershipRes.HasBeenInRoom || membershipRes.Membership == spec.Ban { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), + JSON: spec.Forbidden(fmt.Sprintf("Unknown room %q or user %q has never joined this room", roomID, device.UserID)), } } // Otherwise, if the user has been in the room, whether or not we @@ -265,7 +324,7 @@ func OnIncomingStateTypeRequest( "state_at_event": !wantLatestState, }).Info("Fetching state") - var event *gomatrixserverlib.HeaderedEvent + var event *types.HeaderedEvent if wantLatestState { // If we are happy to use the latest state, either because the user is // still in the room, or because the room is world-readable, then just @@ -293,7 +352,10 @@ func OnIncomingStateTypeRequest( }, &stateAfterRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(stateAfterRes.StateEvents) > 0 { event = stateAfterRes.StateEvents[0] @@ -305,12 +367,14 @@ func OnIncomingStateTypeRequest( if event == nil { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), + JSON: spec.NotFound(fmt.Sprintf("Cannot find state event for %q", evType)), } } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.HeaderedToClientEvent(event, synctypes.FormatAll), + ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, event), } var res interface{} diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index 7a62da4491..b805d4b51c 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/util" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Protocols implements @@ -33,13 +33,16 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev resp := &appserviceAPI.ProtocolResponse{} if err := asAPI.Protocols(req.Context(), &appserviceAPI.ProtocolRequest{Protocol: protocol}, resp); err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !resp.Exists { if protocol != "" { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The protocol is unknown."), + JSON: spec.NotFound("The protocol is unknown."), } } return util.JSONResponse{ @@ -71,12 +74,15 @@ func User(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, device * Protocol: protocol, Params: params.Encode(), }, resp); err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !resp.Exists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The Matrix User ID was not found"), + JSON: spec.NotFound("The Matrix User ID was not found"), } } return util.JSONResponse{ @@ -97,12 +103,15 @@ func Location(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, devi Protocol: protocol, Params: params.Encode(), }, resp); err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !resp.Exists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("No portal rooms were found."), + JSON: spec.NotFound("No portal rooms were found."), } } return util.JSONResponse{ diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 102b1d1cbd..5261a14070 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -19,12 +19,12 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -60,28 +60,37 @@ func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *co if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.QueryLocalpartForThreePID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(res.Localpart) > 0 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ - ErrCode: "M_THREEPID_IN_USE", + JSON: spec.MatrixError{ + ErrCode: spec.ErrorThreePIDInUse, Err: userdb.Err3PIDInUse.Error(), }, } } resp.SID, err = threepid.CreateSession(req.Context(), body, cfg, client) - if err == threepid.ErrNotTrusted { + switch err.(type) { + case nil: + case threepid.ErrNotTrusted: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.IDServer), + JSON: spec.NotTrusted(body.IDServer), } - } else if err != nil { + default: util.GetLogger(req.Context()).WithError(err).Error("threepid.CreateSession failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -102,21 +111,27 @@ func CheckAndSave3PIDAssociation( // Check if the association has been validated verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg, client) - if err == threepid.ErrNotTrusted { + switch err.(type) { + case nil: + case threepid.ErrNotTrusted: + util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.Creds.IDServer), + JSON: spec.NotTrusted(body.Creds.IDServer), } - } else if err != nil { + default: util.GetLogger(req.Context()).WithError(err).Error("threepid.CheckAssociation failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !verified { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MatrixError{ - ErrCode: "M_THREEPID_AUTH_FAILED", + JSON: spec.MatrixError{ + ErrCode: spec.ErrorThreePIDAuthFailed, Err: "Failed to auth 3pid", }, } @@ -127,7 +142,10 @@ func CheckAndSave3PIDAssociation( err = threepid.PublishAssociation(req.Context(), body.Creds, device.UserID, cfg, client) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepid.PublishAssociation failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -135,7 +153,10 @@ func CheckAndSave3PIDAssociation( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ @@ -145,7 +166,10 @@ func CheckAndSave3PIDAssociation( Medium: medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -161,7 +185,10 @@ func GetAssociated3PIDs( localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } res := &api.QueryThreePIDsForLocalpartResponse{} @@ -171,7 +198,10 @@ func GetAssociated3PIDs( }, res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -192,7 +222,10 @@ func Forget3PID(req *http.Request, threepidAPI api.ClientUserAPI) util.JSONRespo Medium: body.Medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.PerformForgetThreePID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index 34c7eb0049..03c0230e67 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -15,16 +15,18 @@ package routing import ( + "errors" "net/http" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -53,42 +55,43 @@ func UpgradeRoom( if _, err := version.SupportedRoomVersion(gomatrixserverlib.RoomVersion(r.NewVersion)); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion("This server does not support that room version"), + JSON: spec.UnsupportedRoomVersion("This server does not support that room version"), } } - upgradeReq := roomserverAPI.PerformRoomUpgradeRequest{ - UserID: device.UserID, - RoomID: roomID, - RoomVersion: gomatrixserverlib.RoomVersion(r.NewVersion), - } - upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{} - - if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - - if upgradeResp.Error != nil { - if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom { + newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion)) + switch e := err.(type) { + case nil: + case roomserverAPI.ErrNotAllowed: + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(e.Error()), + } + default: + if errors.Is(err, eventutil.ErrRoomNoExists{}) { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), - } - } else if upgradeResp.Error.Code == roomserverAPI.PerformErrorNotAllowed { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(upgradeResp.Error.Msg), + JSON: spec.NotFound("Room does not exist"), } - } else { - return jsonerror.InternalServerError() } - + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, JSON: upgradeRoomResponse{ - ReplacementRoom: upgradeResp.NewRoomID, + ReplacementRoom: newRoomID, }, } } diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index a4cf8e9c2d..32cefde63a 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -42,8 +43,8 @@ func SearchUserDirectory( provider userapi.QuerySearchProfilesAPI, searchString string, limit int, - federation *fclient.FederationClient, - localServerName gomatrixserverlib.ServerName, + federation fclient.FederationClient, + localServerName spec.ServerName, ) util.JSONResponse { if limit < 10 { limit = 10 diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go index f0f69ce3c2..14a08b79c8 100644 --- a/clientapi/routing/voip.go +++ b/clientapi/routing/voip.go @@ -25,9 +25,9 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // RequestTurnServer implements: @@ -60,7 +60,10 @@ func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.Client if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mac.Write failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } resp.Password = base64.StdEncoding.EncodeToString(mac.Sum(nil)) diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index a9910b782a..d15cc6d464 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -27,9 +27,11 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // MembershipRequest represents the body of an incoming POST request @@ -62,14 +64,34 @@ type idServerStoreInviteResponse struct { } var ( - // ErrMissingParameter is the error raised if a request for 3PID invite has - // an incomplete body - ErrMissingParameter = errors.New("'address', 'id_server' and 'medium' must all be supplied") - // ErrNotTrusted is the error raised if an identity server isn't in the list - // of trusted servers in the configuration file. - ErrNotTrusted = errors.New("untrusted server") + errMissingParameter = fmt.Errorf("'address', 'id_server' and 'medium' must all be supplied") + errNotTrusted = fmt.Errorf("untrusted server") ) +// ErrMissingParameter is the error raised if a request for 3PID invite has +// an incomplete body +type ErrMissingParameter struct{} + +func (e ErrMissingParameter) Error() string { + return errMissingParameter.Error() +} + +func (e ErrMissingParameter) Unwrap() error { + return errMissingParameter +} + +// ErrNotTrusted is the error raised if an identity server isn't in the list +// of trusted servers in the configuration file. +type ErrNotTrusted struct{} + +func (e ErrNotTrusted) Error() string { + return errNotTrusted.Error() +} + +func (e ErrNotTrusted) Unwrap() error { + return errNotTrusted +} + // CheckAndProcessInvite analyses the body of an incoming membership request. // If the fields relative to a third-party-invite are all supplied, lookups the // matching Matrix ID from the given identity server. If no Matrix ID is @@ -97,7 +119,7 @@ func CheckAndProcessInvite( } else if body.Address == "" || body.IDServer == "" || body.Medium == "" { // If at least one of the 3PID-specific fields is supplied but not all // of them, return an error - err = ErrMissingParameter + err = ErrMissingParameter{} return } @@ -278,7 +300,7 @@ func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) } var pubKeyRes struct { - PublicKey gomatrixserverlib.Base64Bytes `json:"public_key"` + PublicKey spec.Base64Bytes `json:"public_key"` } if resp.StatusCode != http.StatusOK { @@ -333,8 +355,20 @@ func emit3PIDInviteEvent( rsAPI api.ClientRoomserverAPI, evTime time.Time, ) error { - builder := &gomatrixserverlib.EventBuilder{ - Sender: device.UserID, + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return err + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) + if err != nil { + return err + } + proto := &gomatrixserverlib.ProtoEvent{ + SenderID: string(sender), RoomID: roomID, Type: "m.room.third_party_invite", StateKey: &res.Token, @@ -348,7 +382,7 @@ func emit3PIDInviteEvent( PublicKeys: res.PublicKeys, } - if err := builder.SetContent(content); err != nil { + if err = proto.SetContent(content); err != nil { return err } @@ -358,7 +392,7 @@ func emit3PIDInviteEvent( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, proto, identity, evTime, rsAPI, &queryRes) if err != nil { return err } @@ -366,8 +400,8 @@ func emit3PIDInviteEvent( return api.SendEvents( ctx, rsAPI, api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(queryRes.RoomVersion), + []*types.HeaderedEvent{ + event, }, device.UserDomain(), cfg.Matrix.ServerName, diff --git a/clientapi/threepid/threepid.go b/clientapi/threepid/threepid.go index d819d9ddf3..d9249d1509 100644 --- a/clientapi/threepid/threepid.go +++ b/clientapi/threepid/threepid.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) // EmailAssociationRequest represents the request defined at https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-register-email-requesttoken @@ -138,7 +139,7 @@ func CheckAssociation( return false, "", "", err } - if respBody.ErrCode == "M_SESSION_NOT_VALIDATED" { + if respBody.ErrCode == string(spec.ErrorSessionNotValidated) { return false, "", "", nil } else if len(respBody.ErrCode) > 0 { return false, "", "", errors.New(respBody.Error) @@ -191,5 +192,5 @@ func isTrusted(idServer string, cfg *config.ClientAPI) error { return nil } } - return ErrNotTrusted + return ErrNotTrusted{} } diff --git a/clientapi/userutil/userutil.go b/clientapi/userutil/userutil.go index 9be1e9b315..26237142b2 100644 --- a/clientapi/userutil/userutil.go +++ b/clientapi/userutil/userutil.go @@ -19,13 +19,14 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // ParseUsernameParam extracts localpart from usernameParam. // usernameParam can either be a user ID or just the localpart/username. // If serverName is passed, it is verified against the domain obtained from usernameParam (if present) // Returns error in case of invalid usernameParam. -func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomatrixserverlib.ServerName, error) { +func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, spec.ServerName, error) { localpart := usernameParam if strings.HasPrefix(usernameParam, "@") { @@ -45,6 +46,6 @@ func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomat } // MakeUserID generates user ID from localpart & server name -func MakeUserID(localpart string, server gomatrixserverlib.ServerName) string { +func MakeUserID(localpart string, server spec.ServerName) string { return fmt.Sprintf("@%s:%s", localpart, string(server)) } diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index 8910983bc4..cdda0a88a4 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -16,16 +16,16 @@ import ( "testing" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var ( - localpart = "somelocalpart" - serverName gomatrixserverlib.ServerName = "someservername" - invalidServerName gomatrixserverlib.ServerName = "invalidservername" - goodUserID = "@" + localpart + ":" + string(serverName) - badUserID = "@bad:user:name@noservername:" + localpart = "somelocalpart" + serverName spec.ServerName = "someservername" + invalidServerName spec.ServerName = "invalidservername" + goodUserID = "@" + localpart + ":" + string(serverName) + badUserID = "@bad:user:name@noservername:" ) // TestGoodUserID checks that correct localpart is returned for a valid user ID. diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 9a195990cf..25c1475cbe 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" @@ -145,7 +146,7 @@ func main() { } } - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.ServerName = spec.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) configErrors := &config.ConfigErrors{} diff --git a/cmd/dendrite-demo-yggdrasil/signing/fetcher.go b/cmd/dendrite-demo-yggdrasil/signing/fetcher.go index bcec0cbec4..aeaa2022ed 100644 --- a/cmd/dendrite-demo-yggdrasil/signing/fetcher.go +++ b/cmd/dendrite-demo-yggdrasil/signing/fetcher.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const KeyID = "ed25519:dendrite-demo-yggdrasil" @@ -36,7 +37,7 @@ func (f *YggdrasilKeys) KeyRing() *gomatrixserverlib.KeyRing { func (f *YggdrasilKeys) FetchKeys( ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { res := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) for req := range requests { @@ -54,7 +55,7 @@ func (f *YggdrasilKeys) FetchKeys( Key: hexkey, }, ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(24 * time.Hour * 365)), + ValidUntilTS: spec.AsTimestamp(time.Now().Add(24 * time.Hour * 365)), } } return res, nil diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index c25acf2ec7..e1dc0f6681 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -38,7 +38,7 @@ func (n *Node) CreateClient() *fclient.Client { func (n *Node) CreateFederationClient( cfg *config.Dendrite, -) *fclient.FederationClient { +) fclient.FederationClient { tr := &http.Transport{} tr.RegisterProtocol( "matrix", &yggroundtripper{ diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go index 6df5fa879f..26c30e4892 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/node.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go @@ -23,7 +23,7 @@ import ( "regexp" "strings" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/neilalexander/utp" "github.com/sirupsen/logrus" @@ -134,14 +134,14 @@ func (n *Node) PeerCount() int { return len(n.core.GetPeers()) } -func (n *Node) KnownNodes() []gomatrixserverlib.ServerName { +func (n *Node) KnownNodes() []spec.ServerName { nodemap := map[string]struct{}{} for _, peer := range n.core.GetPeers() { nodemap[hex.EncodeToString(peer.Key)] = struct{}{} } - var nodes []gomatrixserverlib.ServerName + var nodes []spec.ServerName for node := range nodemap { - nodes = append(nodes, gomatrixserverlib.ServerName(node)) + nodes = append(nodes, spec.ServerName(node)) } return nodes } diff --git a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go index 180990d54d..7ebecb651d 100644 --- a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go +++ b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go @@ -21,19 +21,19 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggconn" "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) type YggdrasilRoomProvider struct { node *yggconn.Node fedSender api.FederationInternalAPI - fedClient *fclient.FederationClient + fedClient fclient.FederationClient } func NewYggdrasilRoomProvider( - node *yggconn.Node, fedSender api.FederationInternalAPI, fedClient *fclient.FederationClient, + node *yggconn.Node, fedSender api.FederationInternalAPI, fedClient fclient.FederationClient, ) *YggdrasilRoomProvider { p := &YggdrasilRoomProvider{ node: node, @@ -46,7 +46,7 @@ func NewYggdrasilRoomProvider( func (p *YggdrasilRoomProvider) Rooms() []fclient.PublicRoom { return bulkFetchPublicRoomsFromServers( context.Background(), p.fedClient, - gomatrixserverlib.ServerName(p.node.DerivedServerName()), + spec.ServerName(p.node.DerivedServerName()), p.node.KnownNodes(), ) } @@ -54,9 +54,9 @@ func (p *YggdrasilRoomProvider) Rooms() []fclient.PublicRoom { // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( - ctx context.Context, fedClient *fclient.FederationClient, - origin gomatrixserverlib.ServerName, - homeservers []gomatrixserverlib.ServerName, + ctx context.Context, fedClient fclient.FederationClient, + origin spec.ServerName, + homeservers []spec.ServerName, ) (publicRooms []fclient.PublicRoom) { limit := 200 // follow pipeline semantics, see https://blog.golang.org/pipelines for more info. @@ -69,7 +69,7 @@ func bulkFetchPublicRoomsFromServers( wg.Add(len(homeservers)) // concurrently query for public rooms for _, hs := range homeservers { - go func(homeserverDomain gomatrixserverlib.ServerName) { + go func(homeserverDomain spec.ServerName) { defer wg.Done() util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "") diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index b71f1f3e45..dcc45bdcc0 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -72,7 +72,7 @@ RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key # Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password -RUN sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml +RUN sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml # No password when connecting over localhost RUN sed -i "s%127.0.0.1/32 md5%127.0.0.1/32 trust%g" /etc/postgresql/11/main/pg_hba.conf # Bump up max conns for moar concurrency @@ -119,7 +119,7 @@ RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key # Make sure the SQLite databases are in a persistent location, we're already mapping # the postgresql folder so let's just use that for simplicity -RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/11\/main\/%g" dendrite.yaml +RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/11\/main\/%g" dendrite.yaml # This entry script starts postgres, waits for it to be up then starts dendrite RUN echo '\ diff --git a/cmd/dendrite-upgrade-tests/tests.go b/cmd/dendrite-upgrade-tests/tests.go index 03438bd4d5..692ab34ef7 100644 --- a/cmd/dendrite-upgrade-tests/tests.go +++ b/cmd/dendrite-upgrade-tests/tests.go @@ -9,6 +9,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const userPassword = "this_is_a_long_password" @@ -56,7 +57,7 @@ func runTests(baseURL string, v *semver.Version) error { // create DM room, join it and exchange messages createRoomResp, err := users[0].client.CreateRoom(&gomatrix.ReqCreateRoom{ - Preset: "trusted_private_chat", + Preset: spec.PresetTrustedPrivateChat, Invite: []string{users[1].userID}, IsDirect: true, }) @@ -98,7 +99,7 @@ func runTests(baseURL string, v *semver.Version) error { publicRoomID := "" createRoomResp, err = users[0].client.CreateRoom(&gomatrix.ReqCreateRoom{ RoomAliasName: "global", - Preset: "public_chat", + Preset: spec.PresetPublicChat, }) if err != nil { // this is okay and expected if the room already exists and the aliases clash // try to join it diff --git a/cmd/furl/main.go b/cmd/furl/main.go index 32e9970495..cdfef09f73 100644 --- a/cmd/furl/main.go +++ b/cmd/furl/main.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var requestFrom = flag.String("from", "", "the server name that the request should originate from") @@ -49,7 +50,7 @@ func main() { panic("unexpected key block") } - serverName := gomatrixserverlib.ServerName(*requestFrom) + serverName := spec.ServerName(*requestFrom) client := fclient.NewFederationClient( []*fclient.SigningIdentity{ { @@ -83,10 +84,10 @@ func main() { } } - req := gomatrixserverlib.NewFederationRequest( + req := fclient.NewFederationRequest( method, serverName, - gomatrixserverlib.ServerName(u.Host), + spec.ServerName(u.Host), u.RequestURI(), ) @@ -97,7 +98,7 @@ func main() { } if err = req.Sign( - gomatrixserverlib.ServerName(*requestFrom), + spec.ServerName(*requestFrom), gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), privateKey, ); err != nil { diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 86b302346e..b7ac5751e6 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -5,11 +5,11 @@ import ( "fmt" "path/filepath" - "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v2" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" ) func main() { @@ -30,7 +30,7 @@ func main() { SingleDatabase: true, }) if *serverName != "" { - cfg.Global.ServerName = gomatrixserverlib.ServerName(*serverName) + cfg.Global.ServerName = spec.ServerName(*serverName) } uri := config.DataSource(*dbURI) if uri.IsSQLite() || uri == "" { diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 09c0e69079..3ffcac9e6d 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -11,13 +11,16 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // This is a utility for inspecting state snapshots and running state resolution @@ -65,10 +68,14 @@ func main() { panic(err) } + natsInstance := &jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, + natsInstance, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), false) + roomInfo := &types.RoomInfo{ RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), } - stateres := state.NewStateResolution(roomserverDB, roomInfo) + stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI) if *difference { if len(snapshotNIDs) != 2 { @@ -91,14 +98,14 @@ func main() { } var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs) + eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { panic(err) } - events := make(map[types.EventNID]*gomatrixserverlib.Event, len(eventEntries)) + events := make(map[types.EventNID]gomatrixserverlib.PDU, len(eventEntries)) for _, entry := range eventEntries { - events[entry.EventNID] = entry.Event + events[entry.EventNID] = entry.PDU } if len(removed) > 0 { @@ -149,15 +156,15 @@ func main() { } fmt.Println("Fetching", len(eventNIDMap), "state events") - eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs) + eventEntries, err := roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { panic(err) } authEventIDMap := make(map[string]struct{}) - events := make([]*gomatrixserverlib.Event, len(eventEntries)) + events := make([]gomatrixserverlib.PDU, len(eventEntries)) for i := range eventEntries { - events[i] = eventEntries[i].Event + events[i] = eventEntries[i].PDU for _, authEventID := range eventEntries[i].AuthEventIDs() { authEventIDMap[authEventID] = struct{}{} } @@ -174,17 +181,17 @@ func main() { panic(err) } - authEvents := make([]*gomatrixserverlib.Event, len(authEventEntries)) + authEvents := make([]gomatrixserverlib.PDU, len(authEventEntries)) for i := range authEventEntries { - authEvents[i] = authEventEntries[i].Event + authEvents[i] = authEventEntries[i].PDU } fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), - events, - authEvents, + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { panic(err) @@ -208,7 +215,7 @@ func main() { fmt.Println("Returned", count, "state events after filtering") } -type Events []*gomatrixserverlib.Event +type Events []gomatrixserverlib.PDU func (e Events) Len() int { return len(e) diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 6b3ea74f22..96143d85f2 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -69,8 +69,7 @@ global: # e.g. localhost:443 well_known_server_name: "" - # The server name to delegate client-server communications to, with optional port - # e.g. localhost:443 + # The base URL to delegate client-server communications to e.g. https://localhost well_known_client_name: "" # Lists of domains that the server will trust as identity servers to verify third diff --git a/docs/FAQ.md b/docs/FAQ.md index 2000207265..757bf96255 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -24,7 +24,7 @@ No, although a good portion of the Matrix specification has been implemented. Mo Dendrite development is currently supported by a small team of developers and due to those limited resources, the majority of the effort is focused on getting Dendrite to be specification complete. If there are major features you're requesting (e.g. new administration endpoints), we'd like to strongly encourage you to join the community in supporting -the development efforts through [contributing](https://matrix-org.github.io/dendrite/development/contributing). +the development efforts through [contributing](../development/contributing). ## Is there a migration path from Synapse to Dendrite? @@ -103,7 +103,7 @@ This can be done by performing a room upgrade. Use the command `/upgraderoom 0.1) @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.17.0) multipart-post (2.1.1) - nokogiri (1.13.10-arm64-darwin) + nokogiri (1.14.3-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.10-x86_64-linux) + nokogiri (1.14.3-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) @@ -241,7 +241,7 @@ GEM pathutil (0.16.2) forwardable-extended (~> 2.6) public_suffix (4.0.7) - racc (1.6.1) + racc (1.6.2) rb-fsevent (0.11.1) rb-inotify (0.10.1) ffi (~> 1.0) diff --git a/docs/INSTALL.md b/docs/INSTALL.md index ccfc58107a..8e72da9715 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -6,8 +6,8 @@ or alternatively, in the [installation](installation/) folder: 1. [Planning your deployment](installation/1_planning.md) 2. [Setting up the domain](installation/2_domainname.md) -3. [Preparing database storage](installation/3_database.md) -4. [Generating signing keys](installation/4_signingkey.md) -5. [Installing as a monolith](installation/5_install_monolith.md) -6. [Populate the configuration](installation/7_configuration.md) -7. [Starting the monolith](installation/8_starting_monolith.md) +3. [Installing Dendrite](installation/manual/1_build.md) +4. [Preparing database storage](installation/manual/2_database.md) +5. [Populate the configuration](installation/manual/3_configuration.md) +6. [Generating signing keys](installation/manual/4_signingkey.md) +7. [Starting Dendrite](installation/manual/5_starting_dendrite.md) diff --git a/docs/administration/1_createusers.md b/docs/administration/1_createusers.md index 24eba666dc..cbdccd18b7 100644 --- a/docs/administration/1_createusers.md +++ b/docs/administration/1_createusers.md @@ -11,10 +11,9 @@ User accounts can be created on a Dendrite instance in a number of ways. ## From the command line -The `create-account` tool is built in the `bin` folder when building Dendrite with -the `build.sh` script. +The `create-account` tool is built in the `bin` folder when [building](../installation/build) Dendrite. -It uses the `dendrite.yaml` configuration file to connect to a running Dendrite instance and requires +It uses the `dendrite.yaml` configuration file to connect to a **running** Dendrite instance and requires shared secret registration to be enabled as explained below. An example of using `create-account` to create a **normal account**: diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index b11aeb1a60..6f64589977 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -1,6 +1,7 @@ --- title: Supported admin APIs parent: Administration +nav_order: 4 permalink: /administration/adminapi --- @@ -49,13 +50,17 @@ the room IDs of all affected rooms. ## POST `/_dendrite/admin/resetPassword/{userID}` -Reset the password of a local user. +Reset the password of a local user. + +**If `logout_devices` is set to `true`, all `access_tokens` will be invalidated, resulting +in the potential loss of encrypted messages** Request body format: -``` +```json { - "password": "new_password_here" + "password": "new_password_here", + "logout_devices": false } ``` @@ -68,11 +73,14 @@ Indexing is done in the background, the server logs every 1000 events (or below) This endpoint instructs Dendrite to immediately query `/devices/{userID}` on a federated server. An empty JSON body will be returned on success, updating all locally stored user devices/keys. This can be used to possibly resolve E2EE issues, where the remote user can't decrypt messages. +## POST `/_dendrite/admin/purgeRoom/{roomID}` + +This endpoint instructs Dendrite to remove the given room from its database. Before doing so, it will evacuate all local users from the room. It does **NOT** remove media files. Depending on the size of the room, this may take a while. Will return an empty JSON once other components were instructed to delete the room. ## POST `/_synapse/admin/v1/send_server_notice` Request body format: -``` +```json { "user_id": "@target_user:server_name", "content": { @@ -85,7 +93,7 @@ Request body format: Send a server notice to a specific user. See the [Matrix Spec](https://spec.matrix.org/v1.3/client-server-api/#server-notices) for additional details on server notice behaviour. If successfully sent, the API will return the following response: -``` +```json { "event_id": "" } diff --git a/docs/installation/11_optimisation.md b/docs/administration/5_optimisation.md similarity index 90% rename from docs/installation/11_optimisation.md rename to docs/administration/5_optimisation.md index 686ec2eb9b..b327171ebe 100644 --- a/docs/installation/11_optimisation.md +++ b/docs/administration/5_optimisation.md @@ -1,9 +1,9 @@ --- title: Optimise your installation -parent: Installation +parent: Administration has_toc: true -nav_order: 11 -permalink: /installation/start/optimisation +nav_order: 5 +permalink: /administration/optimisation --- # Optimise your installation @@ -36,11 +36,6 @@ connections it will open to the database. **If you are using the `global` database pool** then you only need to configure the `max_open_conns` setting once in the `global` section. -**If you are defining a `database` config per component** then you will need to ensure that -the **sum total** of all configured `max_open_conns` to a given database server do not exceed -the connection limit. If you configure a total that adds up to more connections than are available -then this will cause database queries to fail. - You may wish to raise the `max_connections` limit on your PostgreSQL server to accommodate additional connections, in which case you should also update the `max_open_conns` in your Dendrite configuration accordingly. However be aware that this is only advisable on particularly diff --git a/docs/administration/5_troubleshooting.md b/docs/administration/6_troubleshooting.md similarity index 88% rename from docs/administration/5_troubleshooting.md rename to docs/administration/6_troubleshooting.md index 8ba510ef61..5f11f99316 100644 --- a/docs/administration/5_troubleshooting.md +++ b/docs/administration/6_troubleshooting.md @@ -1,6 +1,7 @@ --- title: Troubleshooting parent: Administration +nav_order: 6 permalink: /administration/troubleshooting --- @@ -18,7 +19,7 @@ be clues in the logs. You can increase this log level to the more verbose `debug` level if necessary by adding this to the config and restarting Dendrite: -``` +```yaml logging: - type: std level: debug @@ -56,12 +57,7 @@ number of database connections does not exceed the maximum allowed by PostgreSQL Open your `postgresql.conf` configuration file and check the value of `max_connections` (which is typically `100` by default). Then open your `dendrite.yaml` configuration file -and ensure that: - -1. If you are using the `global.database` section, that `max_open_conns` does not exceed - that number; -2. If you are **not** using the `global.database` section, that the sum total of all - `max_open_conns` across all `database` blocks does not exceed that number. +and ensure that in the `global.database` section, `max_open_conns` does not exceed that number. ## 5. File descriptors @@ -77,7 +73,7 @@ If there aren't, you will see a log lines like this: level=warning msg="IMPORTANT: Process file descriptor limit is currently 65535, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues" ``` -Follow the [Optimisation](../installation/11_optimisation.md) instructions to correct the +Follow the [Optimisation](5_optimisation.md) instructions to correct the available number of file descriptors. ## 6. STUN/TURN Server tester diff --git a/docs/caddy/monolith/Caddyfile b/docs/caddy/Caddyfile similarity index 100% rename from docs/caddy/monolith/Caddyfile rename to docs/caddy/Caddyfile diff --git a/docs/caddy/polylith/Caddyfile b/docs/caddy/polylith/Caddyfile deleted file mode 100644 index c2d81b49bf..0000000000 --- a/docs/caddy/polylith/Caddyfile +++ /dev/null @@ -1,85 +0,0 @@ -# Sample Caddyfile for using Caddy in front of Dendrite - -# - -# Customize email address and domain names - -# Optional settings commented out - -# - -# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST - -# Documentation: - -# - -# Bonus tip: If your IP address changes, use Caddy's - -# dynamic DNS plugin to update your DNS records to - -# point to your new IP automatically - -# - -# - -# Global options block - -{ - # In case there is a problem with your certificates. - # email example@example.com - - # Turn off the admin endpoint if you don't need graceful config - # changes and/or are running untrusted code on your machine. - # admin off - - # Enable this if your clients don't send ServerName in TLS handshakes. - # default_sni example.com - - # Enable debug mode for verbose logging. - # debug - - # Use Let's Encrypt's staging endpoint for testing. - # acme_ca https://acme-staging-v02.api.letsencrypt.org/directory - - # If you're port-forwarding HTTP/HTTPS ports from 80/443 to something - # else, enable these and put the alternate port numbers here. - # http_port 8080 - # https_port 8443 -} - -# The server name of your matrix homeserver. This example shows - -# "well-known delegation" from the registered domain to a subdomain - -# which is only needed if your server_name doesn't match your Matrix - -# homeserver URL (i.e. you can show users a vanity domain that looks - -# nice and is easy to remember but still have your Matrix server on - -# its own subdomain or hosted service) - -example.com { - header /.well-known/matrix/*Content-Type application/json - header /.well-known/matrix/* Access-Control-Allow-Origin * - respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}` - respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}` -} - -# The actual domain name whereby your Matrix server is accessed - -matrix.example.com { - # Change the end of each reverse_proxy line to the correct - # address for your various services. - @sync_api { - path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ - } - reverse_proxy @sync_api sync_api:8073 - - reverse_proxy /_matrix/client* client_api:8071 - reverse_proxy /_matrix/federation* federation_api:8071 - reverse_proxy /_matrix/key* federation_api:8071 - reverse_proxy /_matrix/media* media_api:8071 -} diff --git a/docs/development/CONTRIBUTING.md b/docs/development/CONTRIBUTING.md index 2aec4c3631..71e7516a29 100644 --- a/docs/development/CONTRIBUTING.md +++ b/docs/development/CONTRIBUTING.md @@ -1,6 +1,7 @@ --- title: Contributing parent: Development +nav_order: 1 permalink: /development/contributing --- diff --git a/docs/development/PROFILING.md b/docs/development/PROFILING.md index 57c37a9006..dc4eca7b75 100644 --- a/docs/development/PROFILING.md +++ b/docs/development/PROFILING.md @@ -1,6 +1,7 @@ --- title: Profiling parent: Development +nav_order: 4 permalink: /development/profiling --- diff --git a/docs/development/coverage.md b/docs/development/coverage.md index c4a8a11743..1b15f71a2d 100644 --- a/docs/development/coverage.md +++ b/docs/development/coverage.md @@ -1,78 +1,130 @@ --- title: Coverage parent: Development +nav_order: 3 permalink: /development/coverage --- -To generate a test coverage report for Sytest, a small patch needs to be applied to the Sytest repository to compile and use the instrumented binary: -```patch -diff --git a/lib/SyTest/Homeserver/Dendrite.pm b/lib/SyTest/Homeserver/Dendrite.pm -index 8f0e209c..ad057e52 100644 ---- a/lib/SyTest/Homeserver/Dendrite.pm -+++ b/lib/SyTest/Homeserver/Dendrite.pm -@@ -337,7 +337,7 @@ sub _start_monolith - - $output->diag( "Starting monolith server" ); - my @command = ( -- $self->{bindir} . '/dendrite', -+ $self->{bindir} . '/dendrite', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", - '--config', $self->{paths}{config}, - '--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port, - '--https-bind-address', $self->{bind_host} . ':' . $self->secure_port, -diff --git a/scripts/dendrite_sytest.sh b/scripts/dendrite_sytest.sh -index f009332b..7ea79869 100755 ---- a/scripts/dendrite_sytest.sh -+++ b/scripts/dendrite_sytest.sh -@@ -34,7 +34,8 @@ export GOBIN=/tmp/bin - echo >&2 "--- Building dendrite from source" - cd /src - mkdir -p $GOBIN --go install -v ./cmd/dendrite -+# go install -v ./cmd/dendrite -+go test -c -cover -covermode=atomic -o $GOBIN/dendrite -coverpkg "github.com/matrix-org/..." ./cmd/dendrite - go install -v ./cmd/generate-keys - cd - - ``` +## Running unit tests with coverage enabled + +Running unit tests with coverage enabled can be done with the following commands, this will generate a `integrationcover.log` +```bash +go test -covermode=atomic -coverpkg=./... -coverprofile=integrationcover.log $(go list ./... | grep -v '/cmd/') +go tool cover -func=integrationcover.log +``` + +## Running Sytest with coverage enabled + +To run Sytest with coverage enabled: + +```bash +docker run --rm --name sytest -v "/Users/kegan/github/sytest:/sytest" \ + -v "/Users/kegan/github/dendrite:/src" -v "$(pwd)/sytest_logs:/logs" \ + -v "/Users/kegan/go/:/gopath" -e "POSTGRES=1" \ + -e "COVER=1" \ + matrixdotorg/sytest-dendrite:latest + +# to get a more accurate coverage you may also need to run Sytest using SQLite as the database: +docker run --rm --name sytest -v "/Users/kegan/github/sytest:/sytest" \ + -v "/Users/kegan/github/dendrite:/src" -v "$(pwd)/sytest_logs:/logs" \ + -v "/Users/kegan/go/:/gopath" \ + -e "COVER=1" \ + matrixdotorg/sytest-dendrite:latest +``` + +This will generate a folder `covdatafiles` in each server's directory, e.g `server-0/covdatafiles`. To parse them, +ensure your working directory is under the Dendrite repository then run: - Then run Sytest. This will generate a new file `integrationcover.log` in each server's directory e.g `server-0/integrationcover.log`. To parse it, - ensure your working directory is under the Dendrite repository then run: ```bash - go tool cover -func=/path/to/server-0/integrationcover.log + go tool covdata func -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" ``` which will produce an output like: ``` ... - github.com/matrix-org/util/json.go:83: NewJSONRequestHandler 100.0% -github.com/matrix-org/util/json.go:90: Protect 57.1% -github.com/matrix-org/util/json.go:110: RequestWithLogging 100.0% -github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0% -github.com/matrix-org/util/json.go:151: respond 61.5% -github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0% -github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0% -github.com/matrix-org/util/json.go:202: RandomString 100.0% -github.com/matrix-org/util/json.go:210: init 100.0% -github.com/matrix-org/util/unique.go:13: Unique 91.7% -github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0% -github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% -total: (statements) 53.7% +github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0% +github.com/matrix-org/util/json.go:151: respond 84.6% +github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0% +github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0% +github.com/matrix-org/util/json.go:202: RandomString 100.0% +github.com/matrix-org/util/json.go:210: init 100.0% +github.com/matrix-org/util/unique.go:13: Unique 91.7% +github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0% +github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% +total (statements) 64.0% +``` +(after running Sytest for Postgres _and_ SQLite) + +The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in different configurations, +which will never be tested in a single test run (e.g sqlite or postgres). To get a more accurate value, you'll need run Sytest for Postgres and SQLite (see commands above). +Additional processing is required also to remove packages which will never be tested and extension MSCs: + +```bash +# If you executed both commands from above, you can get the total coverage using the following commands +go tool covdata textfmt -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov +grep -Ev 'relayapi|setup/mscs' sytest.cov > final.cov +go tool cover -func=final.cov + +# If you only executed the one for Postgres: +go tool covdata textfmt -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov +grep -Ev 'relayapi|sqlite|setup/mscs' sytest.cov > final.cov +go tool cover -func=final.cov + +# If you only executed the one for SQLite: +go tool covdata textfmt -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" -o sytest.cov +grep -Ev 'relayapi|postgres|setup/mscs' sytest.cov > final.cov +go tool cover -func=final.cov ``` -The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations, -which will never be tested in a single test run (e.g sqlite or postgres). To get a more accurate value, additional processing is required -to remove packages which will never be tested and extension MSCs: + +## Getting coverage from Complement + +Getting the coverage for Complement runs is a bit more involved. + +First you'll need a docker image compatible with Complement, one can be built using +```bash +docker build -t complement-dendrite -f build/scripts/Complement.Dockerfile . +``` +from within the Dendrite repository. + +Clone complement to a directory of your liking: ```bash -# These commands are all similar but change which package paths are _removed_ from the output. +git clone https://github.com/matrix-org/complement.git +cd complement +``` -# For Postgres -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt +Next we'll need a script to execute after a test finishes, create a new file `posttest.sh`, make the file executable (`chmod +x posttest.sh`) +and add the following content: +```bash +#!/bin/bash -# For SQLite -go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt +mkdir -p /tmp/Complement/logs/$2/$1/ +docker cp $1:/tmp/covdatafiles/. /tmp/Complement/logs/$2/$1/ ``` +This will copy the `covdatafiles` files from each container to something like +`/tmp/Complement/logs/TestLogin/94f9c428de95779d2b62a3ccd8eab9d5ddcf65cc259a40ece06bdc61687ffed3/`. (`$1` is the containerID, `$2` the test name) -A total value can then be calculated using: +Now that we have set up everything we need, we can finally execute Complement: ```bash -cat coverage.txt | awk -F '\t+' '{x = x + $3} END {print x/NR}' +COMPLEMENT_BASE_IMAGE=complement-dendrite \ +COMPLEMENT_SHARE_ENV_PREFIX=COMPLEMENT_DENDRITE_ \ +COMPLEMENT_DENDRITE_COVER=1 \ +COMPLEMENT_POST_TEST_SCRIPT=$(pwd)/posttest.sh \ + go test -tags dendrite_blacklist ./tests/... -count=1 -v -timeout=30m -failfast=false ``` +Once this is done, you can copy the resulting `covdatafiles` files to your Dendrite repository for the next step. +```bash +cp -pr /tmp/Complement/logs PathToYourDendriteRepository +``` -We currently do not have a way to combine Sytest/Complement/Unit Tests into a single coverage report. \ No newline at end of file +You can also run the following to get the coverage for Complement runs alone: +```bash +go tool covdata func -i="$(find /tmp/Complement -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" +``` + +## Combining the results of (almost) all runs + +Now that we have all our `covdatafiles` files within the Dendrite repository, you can now execute the following command, to get the coverage +overall (excluding unit tests): +```bash +go tool covdata func -i="$(find -name 'covmeta*' -type f -exec dirname {} \; | uniq | paste -s -d ',' -)" +``` \ No newline at end of file diff --git a/docs/development/sytest.md b/docs/development/sytest.md index 4fae2ea3d1..2f681f3e5b 100644 --- a/docs/development/sytest.md +++ b/docs/development/sytest.md @@ -1,6 +1,7 @@ --- title: SyTest parent: Development +nav_order: 2 permalink: /development/sytest --- @@ -23,7 +24,7 @@ After running the tests, a script will print the tests you need to add to You should proceed after you see no build problems for dendrite after running: ```sh -./build.sh +go build -o bin/ ./cmd/... ``` If you are fixing an issue marked with @@ -61,6 +62,8 @@ When debugging, the following Docker `run` options may also be useful: * `-e "DENDRITE_TRACE_HTTP=1"`: Adds HTTP tracing to server logs. * `-e "DENDRITE_TRACE_INTERNAL=1"`: Adds roomserver internal API tracing to server logs. +* `-e "COVER=1"`: Run Sytest with an instrumented binary, producing a Go coverage file per server. +* `-e "RACE_DETECTION=1"`: Build the binaries with the `-race` flag (Note: This will significantly slow down test runs) The docker command also supports a single positional argument for the test file to run, so you can run a single `.pl` file rather than the whole test suite. For example: @@ -71,68 +74,3 @@ docker run --rm --name sytest -v "/Users/kegan/github/sytest:/sytest" -v "/Users/kegan/go/:/gopath" -e "POSTGRES=1" -e "DENDRITE_TRACE_HTTP=1" matrixdotorg/sytest-dendrite:latest tests/50federation/40devicelists.pl ``` - -### Manually Setting up SyTest - -**We advise AGAINST using manual SyTest setups.** - -If you don't want to use the Docker image, you can also run SyTest by hand. Make -sure you have Perl 5 or above, and get SyTest with: - -(Note that this guide assumes your SyTest checkout is next to your -`dendrite` checkout.) - -```sh -git clone -b develop https://github.com/matrix-org/sytest -cd sytest -./install-deps.pl -``` - -Set up the database: - -```sh -sudo -u postgres psql -c "CREATE USER dendrite PASSWORD 'itsasecret'" -sudo -u postgres psql -c "ALTER USER dendrite CREATEDB" -for i in dendrite0 dendrite1 sytest_template; do sudo -u postgres psql -c "CREATE DATABASE $i OWNER dendrite;"; done -mkdir -p "server-0" -cat > "server-0/database.yaml" << EOF -args: - user: dendrite - password: itsasecret - database: dendrite0 - host: 127.0.0.1 - sslmode: disable -type: pg -EOF -mkdir -p "server-1" -cat > "server-1/database.yaml" << EOF -args: - user: dendrite - password: itsasecret - database: dendrite1 - host: 127.0.0.1 - sslmode: disable -type: pg -EOF -``` - -Run the tests: - -```sh -POSTGRES=1 ./run-tests.pl -I Dendrite::Monolith -d ../dendrite/bin -W ../dendrite/sytest-whitelist -O tap --all | tee results.tap -``` - -where `tee` lets you see the results while they're being piped to the file, and -`POSTGRES=1` enables testing with PostgeSQL. If the `POSTGRES` environment -variable is not set or is set to 0, SyTest will fall back to SQLite 3. For more -flags and options, see . - -Once the tests are complete, run the helper script to see if you need to add -any newly passing test names to `sytest-whitelist` in the project's root -directory: - -```sh -../dendrite/show-expected-fail-tests.sh results.tap ../dendrite/sytest-whitelist ../dendrite/sytest-blacklist -``` - -If the script prints nothing/exits with 0, then you're good to go. diff --git a/docs/development/tracing/opentracing.md b/docs/development/tracing/opentracing.md deleted file mode 100644 index 8528c2ba3f..0000000000 --- a/docs/development/tracing/opentracing.md +++ /dev/null @@ -1,114 +0,0 @@ ---- -title: OpenTracing -has_children: true -parent: Development -permalink: /development/opentracing ---- - -# OpenTracing - -Dendrite extensively uses the [opentracing.io](http://opentracing.io) framework -to trace work across the different logical components. - -At its most basic opentracing tracks "spans" of work; recording start and end -times as well as any parent span that caused the piece of work. - -A typical example would be a new span being created on an incoming request that -finishes when the response is sent. When the code needs to hit out to a -different component a new span is created with the initial span as its parent. -This would end up looking roughly like: - -``` -Received request Sent response - |<───────────────────────────────────────>| - |<────────────────────>| - RPC call RPC call returns -``` - -This is useful to see where the time is being spent processing a request on a -component. However, opentracing allows tracking of spans across components. This -makes it possible to see exactly what work goes into processing a request: - -``` -Component 1 |<─────────────────── HTTP ────────────────────>| - |<──────────────── RPC ─────────────────>| -Component 2 |<─ SQL ─>| |<── RPC ───>| -Component 3 |<─ SQL ─>| -``` - -This is achieved by serializing span information during all communication -between components. For HTTP requests, this is achieved by the sender -serializing the span into a HTTP header, and the receiver deserializing the span -on receipt. (Generally a new span is then immediately created with the -deserialized span as the parent). - -A collection of spans that are related is called a trace. - -Spans are passed through the code via contexts, rather than manually. It is -therefore important that all spans that are created are immediately added to the -current context. Thankfully the opentracing library gives helper functions for -doing this: - -```golang -span, ctx := opentracing.StartSpanFromContext(ctx, spanName) -defer span.Finish() -``` - -This will create a new span, adding any span already in `ctx` as a parent to the -new span. - -Adding Information ------------------- - -Opentracing allows adding information to a trace via three mechanisms: - -- "tags" ─ A span can be tagged with a key/value pair. This is typically - information that relates to the span, e.g. for spans created for incoming HTTP - requests could include the request path and response codes as tags, spans for - SQL could include the query being executed. -- "logs" ─ Key/value pairs can be looged at a particular instance in a trace. - This can be useful to log e.g. any errors that happen. -- "baggage" ─ Arbitrary key/value pairs can be added to a span to which all - child spans have access. Baggage isn't saved and so isn't available when - inspecting the traces, but can be used to add context to logs or tags in child - spans. - -See -[specification.md](https://github.com/opentracing/specification/blob/master/specification.md) -for some of the common tags and log fields used. - -Span Relationships ------------------- - -Spans can be related to each other. The most common relation is `childOf`, which -indicates the child span somehow depends on the parent span ─ typically the -parent span cannot complete until all child spans are completed. - -A second relation type is `followsFrom`, where the parent has no dependence on -the child span. This usually indicates some sort of fire and forget behaviour, -e.g. adding a message to a pipeline or inserting into a kafka topic. - -Jaeger ------- - -Opentracing is just a framework. We use -[jaeger](https://github.com/jaegertracing/jaeger) as the actual implementation. - -Jaeger is responsible for recording, sending and saving traces, as well as -giving a UI for viewing and interacting with traces. - -To enable jaeger a `Tracer` object must be instansiated from the config (as well -as having a jaeger server running somewhere, usually locally). A `Tracer` does -several things: - -- Decides which traces to save and send to the server. There are multiple - schemes for doing this, with a simple example being to save a certain fraction - of traces. -- Communicating with the jaeger backend. If not explicitly specified uses the - default port on localhost. -- Associates a service name to all spans created by the tracer. This service - name equates to a logical component, e.g. spans created by clientapi will have - a different service name than ones created by the syncapi. Database access - will also typically use a different service name. - - This means that there is a tracer per service name/component. diff --git a/docs/development/tracing/setup.md b/docs/development/tracing/setup.md deleted file mode 100644 index cef1089e46..0000000000 --- a/docs/development/tracing/setup.md +++ /dev/null @@ -1,57 +0,0 @@ ---- -title: Setup -parent: OpenTracing -grand_parent: Development -permalink: /development/opentracing/setup ---- - -# OpenTracing Setup - -Dendrite uses [Jaeger](https://www.jaegertracing.io/) for tracing between microservices. -Tracing shows the nesting of logical spans which provides visibility on how the microservices interact. -This document explains how to set up Jaeger locally on a single machine. - -## Set up the Jaeger backend - -The [easiest way](https://www.jaegertracing.io/docs/1.18/getting-started/) is to use the all-in-one Docker image: - -``` -$ docker run -d --name jaeger \ - -e COLLECTOR_ZIPKIN_HTTP_PORT=9411 \ - -p 5775:5775/udp \ - -p 6831:6831/udp \ - -p 6832:6832/udp \ - -p 5778:5778 \ - -p 16686:16686 \ - -p 14268:14268 \ - -p 14250:14250 \ - -p 9411:9411 \ - jaegertracing/all-in-one:1.18 -``` - -## Configuring Dendrite to talk to Jaeger - -Modify your config to look like: (this will send every single span to Jaeger which will be slow on large instances, but for local testing it's fine) - -``` -tracing: - enabled: true - jaeger: - serviceName: "dendrite" - disabled: false - rpc_metrics: true - tags: [] - sampler: - type: const - param: 1 -``` - -then run the monolith server: - -``` -./dendrite --tls-cert server.crt --tls-key server.key --config dendrite.yaml -``` - -## Checking traces - -Visit to see traces under `DendriteMonolith`. diff --git a/docs/hiawatha/monolith-sample.conf b/docs/hiawatha/dendrite-sample.conf similarity index 100% rename from docs/hiawatha/monolith-sample.conf rename to docs/hiawatha/dendrite-sample.conf diff --git a/docs/hiawatha/polylith-sample.conf b/docs/hiawatha/polylith-sample.conf deleted file mode 100644 index eb1dd4f9ae..0000000000 --- a/docs/hiawatha/polylith-sample.conf +++ /dev/null @@ -1,35 +0,0 @@ -# Depending on which port is used for federation (.well-known/matrix/server or SRV record), -# ensure there's a binding for that port in the configuration. Replace "FEDPORT" with port -# number, (e.g. "8448"), and "IPV4" with your server's ipv4 address (separate binding for -# each ip address, e.g. if you use both ipv4 and ipv6 addresses). - -Binding { - Port = FEDPORT - Interface = IPV4 - TLScertFile = /path/to/fullchainandprivkey.pem -} - - -VirtualHost { - ... - # route requests to: - # /_matrix/client/.*/sync - # /_matrix/client/.*/user/{userId}/filter - # /_matrix/client/.*/user/{userId}/filter/{filterID} - # /_matrix/client/.*/keys/changes - # /_matrix/client/.*/rooms/{roomId}/messages - # /_matrix/client/.*/rooms/{roomId}/context/{eventID} - # /_matrix/client/.*/rooms/{roomId}/event/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} - # /_matrix/client/.*/rooms/{roomId}/members - # /_matrix/client/.*/rooms/{roomId}/joined_members - # to sync_api - ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600 - ReverseProxy = /_matrix/client http://localhost:8071 600 - ReverseProxy = /_matrix/federation http://localhost:8072 600 - ReverseProxy = /_matrix/key http://localhost:8072 600 - ReverseProxy = /_matrix/media http://localhost:8074 600 - ... -} diff --git a/docs/installation/1_planning.md b/docs/installation/1_planning.md index 36d90abdaa..354003aef3 100644 --- a/docs/installation/1_planning.md +++ b/docs/installation/1_planning.md @@ -7,23 +7,13 @@ permalink: /installation/planning # Planning your installation -## Modes - -Dendrite consists of several components, each responsible for a different aspect of the Matrix protocol. -Users can run Dendrite in one of two modes which dictate how these components are executed and communicate. - -* **Monolith mode** runs all components in a single process. Components communicate through an internal NATS - server with generally low overhead. This mode dramatically simplifies deployment complexity and offers the - best balance between performance and resource usage for low-to-mid volume deployments. - - -## Databases +## Database Dendrite can run with either a PostgreSQL or a SQLite backend. There are considerable tradeoffs to consider: * **PostgreSQL**: Needs to run separately to Dendrite, needs to be installed and configured separately - and and will use more resources over all, but will be **considerably faster** than SQLite. PostgreSQL + and will use more resources over all, but will be **considerably faster** than SQLite. PostgreSQL has much better write concurrency which will allow Dendrite to process more tasks in parallel. This will be necessary for federated deployments to perform adequately. @@ -80,18 +70,17 @@ If using the PostgreSQL database engine, you should install PostgreSQL 12 or lat ### NATS Server Dendrite comes with a built-in [NATS Server](https://github.com/nats-io/nats-server) and -therefore does not need this to be manually installed. If you are planning a monolith installation, you -do not need to do anything. +therefore does not need this to be manually installed. ### Reverse proxy A reverse proxy such as [Caddy](https://caddyserver.com), [NGINX](https://www.nginx.com) or -[HAProxy](http://www.haproxy.org) is useful for deployments. Configuring those is not covered in this documentation, although sample configurations +[HAProxy](http://www.haproxy.org) is useful for deployments. Configuring this is not covered in this documentation, although sample configurations for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy) and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx) are provided. ### Windows -Finally, if you want to build Dendrite on Windows, you will need need `gcc` in the path. The best +Finally, if you want to build Dendrite on Windows, you will need `gcc` in the path. The best way to achieve this is by installing and building Dendrite under [MinGW-w64](https://www.mingw-w64.org/). diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md index 545a2daf6e..d86a664cb1 100644 --- a/docs/installation/2_domainname.md +++ b/docs/installation/2_domainname.md @@ -20,7 +20,7 @@ Matrix servers usually discover each other when federating using the following m well-known file to connect to the remote homeserver; 2. If a DNS SRV delegation exists on `example.com`, use the IP address and port from the DNS SRV record to connect to the remote homeserver; -3. If neither well-known or DNS SRV delegation are configured, attempt to connect to the remote +3. If neither well-known nor DNS SRV delegation are configured, attempt to connect to the remote homeserver by connecting to `example.com` port TCP/8448 using HTTPS. The exact details of how server name resolution works can be found in diff --git a/docs/installation/5_install_monolith.md b/docs/installation/5_install_monolith.md deleted file mode 100644 index 901975a654..0000000000 --- a/docs/installation/5_install_monolith.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -title: Installing as a monolith -parent: Installation -has_toc: true -nav_order: 5 -permalink: /installation/install/monolith ---- - -# Installing as a monolith - -You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: - -```sh -go install ./cmd/dendrite -``` - -Alternatively, you can specify a custom path for the binary to be written to using `go build`: - -```sh -go build -o /usr/local/bin/ ./cmd/dendrite -``` diff --git a/docs/installation/9_starting_monolith.md b/docs/installation/9_starting_monolith.md deleted file mode 100644 index d7e8c0b8ba..0000000000 --- a/docs/installation/9_starting_monolith.md +++ /dev/null @@ -1,42 +0,0 @@ ---- -title: Starting the monolith -parent: Installation -has_toc: true -nav_order: 9 -permalink: /installation/start/monolith ---- - -# Starting the monolith - -Once you have completed all of the preparation and installation steps, -you can start your Dendrite monolith deployment by starting `dendrite`: - -```bash -./dendrite -config /path/to/dendrite.yaml -``` - -By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses -or ports that Dendrite listens on, you can use the `-http-bind-address` and -`-https-bind-address` command line arguments: - -```bash -./dendrite -config /path/to/dendrite.yaml \ - -http-bind-address 1.2.3.4:12345 \ - -https-bind-address 1.2.3.4:54321 -``` - -## Running under systemd - -A common deployment pattern is to run the monolith under systemd. For this, you -will need to create a service unit file. An example service unit file is available -in the [GitHub repository](https://github.com/matrix-org/dendrite/blob/main/docs/systemd/monolith-example.service). - -Once you have installed the service unit, you can notify systemd, enable and start -the service: - -```bash -systemctl daemon-reload -systemctl enable dendrite -systemctl start dendrite -journalctl -fu dendrite -``` diff --git a/docs/installation/docker.md b/docs/installation/docker.md new file mode 100644 index 0000000000..1ecc7c6ee9 --- /dev/null +++ b/docs/installation/docker.md @@ -0,0 +1,11 @@ +--- +title: Docker +parent: Installation +has_children: true +nav_order: 4 +permalink: /docker +--- + +# Installation using Docker + +This section contains documentation how to install Dendrite using Docker diff --git a/docs/installation/docker/1_docker.md b/docs/installation/docker/1_docker.md new file mode 100644 index 0000000000..1fe7926366 --- /dev/null +++ b/docs/installation/docker/1_docker.md @@ -0,0 +1,57 @@ +--- +title: Installation +parent: Docker +grand_parent: Installation +has_toc: true +nav_order: 1 +permalink: /installation/docker/install +--- + +# Installing Dendrite using Docker Compose + +Dendrite provides an [example](https://github.com/matrix-org/dendrite/blob/main/build/docker/docker-compose.yml) +Docker compose file, which needs some preparation to start successfully. +Please note that this compose file only has Postgres as a dependency, and you need to configure +a [reverse proxy](../planning#reverse-proxy). + +## Preparations + +### Generate a private key + +First we'll generate private key, which is used to sign events, the following will create one in `./config`: + +```bash +mkdir -p ./config +docker run --rm --entrypoint="/usr/bin/generate-keys" \ + -v $(pwd)/config:/mnt \ + matrixdotorg/dendrite-monolith:latest \ + -private-key /mnt/matrix_key.pem +``` +(**NOTE**: This only needs to be executed **once**, as you otherwise overwrite the key) + +### Generate a config + +Similar to the command above, we can generate a config to be used, which will use the correct paths +as specified in the example docker-compose file. Change `server` to your domain and `db` according to your changes +to the docker-compose file (`services.postgres.environment` values): + +```bash +mkdir -p ./config +docker run --rm --entrypoint="/bin/sh" \ + -v $(pwd)/config:/mnt \ + matrixdotorg/dendrite-monolith:latest \ + -c "/usr/bin/generate-config \ + -dir /var/dendrite/ \ + -db postgres://dendrite:itsasecret@postgres/dendrite?sslmode=disable \ + -server YourDomainHere > /mnt/dendrite.yaml" +``` + +You can then change `config/dendrite.yaml` to your liking. + +## Starting Dendrite + +Once you're done changing the config, you can now start up Dendrite with + +```bash +docker-compose -f docker-compose.yml up +``` diff --git a/docs/installation/helm.md b/docs/installation/helm.md new file mode 100644 index 0000000000..dd20e0261b --- /dev/null +++ b/docs/installation/helm.md @@ -0,0 +1,11 @@ +--- +title: Helm +parent: Installation +has_children: true +nav_order: 3 +permalink: /helm +--- + +# Helm + +This section contains documentation how to use [Helm](https://helm.sh/) to install Dendrite on a [Kubernetes](https://kubernetes.io/) cluster. diff --git a/docs/installation/helm/1_helm.md b/docs/installation/helm/1_helm.md new file mode 100644 index 0000000000..00fe4fdcaf --- /dev/null +++ b/docs/installation/helm/1_helm.md @@ -0,0 +1,58 @@ +--- +title: Installation +parent: Helm +grand_parent: Installation +has_toc: true +nav_order: 1 +permalink: /installation/helm/install +--- + +# Installing Dendrite using Helm + +To install Dendrite using the Helm chart, you first have to add the repository using the following commands: + +```bash +helm repo add dendrite https://matrix-org.github.io/dendrite/ +helm repo update +``` + +Next you'll need to create a `values.yaml` file and configure it to your liking. All possible values can be found +[here](https://github.com/matrix-org/dendrite/blob/main/helm/dendrite/values.yaml), but at least you need to configure +a `server_name`, otherwise the chart will complain about it: + +```yaml +dendrite_config: + global: + server_name: "localhost" +``` + +If you are going to use an existing Postgres database, you'll also need to configure this connection: + +```yaml +dendrite_config: + global: + database: + connection_string: "postgresql://PostgresUser:PostgresPassword@PostgresHostName/DendriteDatabaseName" + max_open_conns: 90 + max_idle_conns: 5 + conn_max_lifetime: -1 +``` + +## Installing with PostgreSQL + +The chart comes with a dependency on Postgres, which can be installed alongside Dendrite, this needs to be enabled in +the `values.yaml`: + +```yaml +postgresql: + enabled: true # this installs Postgres + primary: + persistence: + size: 1Gi # defines the size for $PGDATA + +dendrite_config: + global: + server_name: "localhost" +``` + +Using this option, the `database.connection_string` will be set for you automatically. \ No newline at end of file diff --git a/docs/installation/manual.md b/docs/installation/manual.md new file mode 100644 index 0000000000..3ab1fd6275 --- /dev/null +++ b/docs/installation/manual.md @@ -0,0 +1,11 @@ +--- +title: Manual +parent: Installation +has_children: true +nav_order: 5 +permalink: /manual +--- + +# Manual Installation + +This section contains documentation how to manually install Dendrite diff --git a/docs/installation/3_build.md b/docs/installation/manual/1_build.md similarity index 53% rename from docs/installation/3_build.md rename to docs/installation/manual/1_build.md index 824c81d37f..73a6268820 100644 --- a/docs/installation/3_build.md +++ b/docs/installation/manual/1_build.md @@ -1,31 +1,26 @@ --- -title: Building Dendrite -parent: Installation +title: Building/Installing Dendrite +parent: Manual +grand_parent: Installation has_toc: true -nav_order: 3 -permalink: /installation/build +nav_order: 1 +permalink: /installation/manual/build --- # Build all Dendrite commands Dendrite has numerous utility commands in addition to the actual server binaries. -Build them all from the root of the source repo with `build.sh` (Linux/Mac): +Build them all from the root of the source repo with: ```sh -./build.sh -``` - -or `build.cmd` (Windows): - -```powershell -build.cmd +go build -o bin/ ./cmd/... ``` The resulting binaries will be placed in the `bin` subfolder. -# Installing as a monolith +# Installing Dendrite -You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`: +You can install the Dendrite binary into `$GOPATH/bin` by using `go install`: ```sh go install ./cmd/dendrite diff --git a/docs/installation/4_database.md b/docs/installation/manual/2_database.md similarity index 57% rename from docs/installation/4_database.md rename to docs/installation/manual/2_database.md index d64ee6615d..1be602c663 100644 --- a/docs/installation/4_database.md +++ b/docs/installation/manual/2_database.md @@ -1,8 +1,10 @@ --- title: Preparing database storage parent: Installation -nav_order: 3 -permalink: /installation/database +nav_order: 2 +parent: Manual +grand_parent: Installation +permalink: /installation/manual/database --- # Preparing database storage @@ -13,31 +15,22 @@ may need to perform some manual steps outlined below. ## PostgreSQL Dendrite can automatically populate the database with the relevant tables and indexes, but -it is not capable of creating the databases themselves. You will need to create the databases +it is not capable of creating the database itself. You will need to create the database manually. -The databases **must** be created with UTF-8 encoding configured or you will likely run into problems +The database **must** be created with UTF-8 encoding configured, or you will likely run into problems with your Dendrite deployment. -At this point, you can choose to either use a single database for all Dendrite components, -or you can run each component with its own separate database: +You will need to create a single PostgreSQL database. Deployments +can use a single global connection pool, which makes updating the configuration file much easier. +Only one database connection string to manage and likely simpler to back up the database. All +components will be sharing the same database resources (CPU, RAM, storage). -* **Single database**: You will need to create a single PostgreSQL database. Monolith deployments - can use a single global connection pool, which makes updating the configuration file much easier. - Only one database connection string to manage and likely simpler to back up the database. All - components will be sharing the same database resources (CPU, RAM, storage). - -* **Separate databases**: You will need to create a separate PostgreSQL database for each - component. You will need to configure each component that has storage in the Dendrite - configuration file with its own connection parameters. Allows running a different database engine - for each component on a different machine if needs be, each with their own CPU, RAM and storage — - almost certainly overkill unless you are running a very large Dendrite deployment. - -For either configuration, you will want to: +You will most likely want to: 1. Configure a role (with a username and password) which Dendrite can use to connect to the database; -2. Create the database(s) themselves, ensuring that the Dendrite role has privileges over them. +2. Create the database itself, ensuring that the Dendrite role has privileges over them. As Dendrite will create and manage the database tables, indexes and sequences by itself, the Dendrite role must have suitable privileges over the database. @@ -71,27 +64,6 @@ Create the database itself, using the `dendrite` role from above: sudo -u postgres createdb -O dendrite -E UTF-8 dendrite ``` -### Multiple database creation - -The following eight components require a database. In this example they will be named: - -| Appservice API | `dendrite_appservice` | -| Federation API | `dendrite_federationapi` | -| Media API | `dendrite_mediaapi` | -| MSCs | `dendrite_mscs` | -| Roomserver | `dendrite_roomserver` | -| Sync API | `dendrite_syncapi` | -| Key server | `dendrite_keyserver` | -| User API | `dendrite_userapi` | - -... therefore you will need to create eight different databases: - -```bash -for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do - sudo -u postgres createdb -O dendrite -E UTF-8 dendrite_$i -done -``` - ## SQLite **WARNING:** The Dendrite SQLite backend is slower, less reliable and not recommended for diff --git a/docs/installation/8_signingkey.md b/docs/installation/manual/3_signingkey.md similarity index 92% rename from docs/installation/8_signingkey.md rename to docs/installation/manual/3_signingkey.md index 323759a88f..91289fd6ac 100644 --- a/docs/installation/8_signingkey.md +++ b/docs/installation/manual/3_signingkey.md @@ -1,8 +1,9 @@ --- title: Generating signing keys -parent: Installation -nav_order: 8 -permalink: /installation/signingkeys +parent: Manual +grand_parent: Installation +nav_order: 3 +permalink: /installation/manual/signingkeys --- # Generating signing keys @@ -11,7 +12,7 @@ All Matrix homeservers require a signing private key, which will be used to auth federation requests and events. The `generate-keys` utility can be used to generate a private key. Assuming that Dendrite was -built using `build.sh`, you should find the `generate-keys` utility in the `bin` folder. +built using `go build -o bin/ ./cmd/...`, you should find the `generate-keys` utility in the `bin` folder. To generate a Matrix signing private key: diff --git a/docs/installation/7_configuration.md b/docs/installation/manual/4_configuration.md similarity index 66% rename from docs/installation/7_configuration.md rename to docs/installation/manual/4_configuration.md index 0cc67b1561..624cc4155f 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/manual/4_configuration.md @@ -1,8 +1,9 @@ --- title: Configuring Dendrite -parent: Installation -nav_order: 7 -permalink: /installation/configuration +parent: Manual +grand_parent: Installation +nav_order: 4 +permalink: /installation/manual/configuration --- # Configuring Dendrite @@ -20,7 +21,7 @@ sections: First of all, you will need to configure the server name of your Matrix homeserver. This must match the domain name that you have selected whilst [configuring the domain -name delegation](domainname). +name delegation](../domainname#delegation). In the `global` section, set the `server_name` to your delegated domain name: @@ -44,7 +45,7 @@ global: ## JetStream configuration -Monolith deployments can use the built-in NATS Server rather than running a standalone +Dendrite deployments can use the built-in NATS Server rather than running a standalone server. If you want to use a standalone NATS Server anyway, you can also configure that too. ### Built-in NATS Server @@ -56,7 +57,6 @@ configured and set a `storage_path` to a persistent folder on the filesystem: global: # ... jetstream: - in_memory: false storage_path: /path/to/storage/folder topic_prefix: Dendrite ``` @@ -79,22 +79,17 @@ You do not need to configure the `storage_path` when using a standalone NATS Ser In the case that you are connecting to a multi-node NATS cluster, you can configure more than one address in the `addresses` field. -## Database connections +## Database connection using a global connection pool -Configuring database connections varies based on the [database configuration](database) -that you chose. - -### Global connection pool - -If you want to use a single connection pool to a single PostgreSQL database, then you must -uncomment and configure the `database` section within the `global` section: +If you want to use a single connection pool to a single PostgreSQL database, +then you must uncomment and configure the `database` section within the `global` section: ```yaml global: # ... database: connection_string: postgres://user:pass@hostname/database?sslmode=disable - max_open_conns: 100 + max_open_conns: 90 max_idle_conns: 5 conn_max_lifetime: -1 ``` @@ -104,42 +99,13 @@ configuration file, e.g. under the `app_service_api`, `federation_api`, `key_ser `media_api`, `mscs`, `relay_api`, `room_server`, `sync_api` and `user_api` blocks, otherwise these will override the `global` database configuration. -### Per-component connections (all other configurations) - -If you are are using SQLite databases or separate PostgreSQL -databases per component, then you must instead configure the `database` sections under each -of the component blocks ,e.g. under the `app_service_api`, `federation_api`, `key_server`, -`media_api`, `mscs`, `relay_api`, `room_server`, `sync_api` and `user_api` blocks. - -For example, with PostgreSQL: - -```yaml -room_server: - # ... - database: - connection_string: postgres://user:pass@hostname/dendrite_component?sslmode=disable - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 -``` - -... or with SQLite: - -```yaml -room_server: - # ... - database: - connection_string: file:roomserver.db - max_open_conns: 10 - max_idle_conns: 2 - conn_max_lifetime: -1 -``` - ## Full-text search -Dendrite supports experimental full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows. +Dendrite supports full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows. -Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). +Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, +to ensure the returned results match the expectations. A full list of possible languages +can be found [here](https://github.com/matrix-org/dendrite/blob/5b73592f5a4dddf64184fcbe33f4c1835c656480/internal/fulltext/bleve.go#L25-L46). ```yaml sync_api: diff --git a/docs/installation/manual/5_starting_dendrite.md b/docs/installation/manual/5_starting_dendrite.md new file mode 100644 index 0000000000..d135043720 --- /dev/null +++ b/docs/installation/manual/5_starting_dendrite.md @@ -0,0 +1,26 @@ +--- +title: Starting Dendrite +parent: Manual +grand_parent: Installation +nav_order: 5 +permalink: /installation/manual/start +--- + +# Starting Dendrite + +Once you have completed all preparation and installation steps, +you can start your Dendrite deployment by executing the `dendrite` binary: + +```bash +./dendrite -config /path/to/dendrite.yaml +``` + +By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses +or ports that Dendrite listens on, you can use the `-http-bind-address` and +`-https-bind-address` command line arguments: + +```bash +./dendrite -config /path/to/dendrite.yaml \ + -http-bind-address 1.2.3.4:12345 \ + -https-bind-address 1.2.3.4:54321 +``` diff --git a/docs/nginx/monolith-sample.conf b/docs/nginx/dendrite-sample.conf similarity index 100% rename from docs/nginx/monolith-sample.conf rename to docs/nginx/dendrite-sample.conf diff --git a/docs/nginx/polylith-sample.conf b/docs/nginx/polylith-sample.conf deleted file mode 100644 index 0ad24509a4..0000000000 --- a/docs/nginx/polylith-sample.conf +++ /dev/null @@ -1,58 +0,0 @@ -server { - listen 443 ssl; # IPv4 - listen [::]:443 ssl; # IPv6 - server_name my.hostname.com; - - ssl_certificate /path/to/fullchain.pem; - ssl_certificate_key /path/to/privkey.pem; - ssl_dhparam /path/to/ssl-dhparams.pem; - - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_read_timeout 600; - - location /.well-known/matrix/server { - return 200 '{ "m.server": "my.hostname.com:443" }'; - } - - location /.well-known/matrix/client { - # If your sever_name here doesn't match your matrix homeserver URL - # (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL) - # add_header Access-Control-Allow-Origin '*'; - return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }'; - } - - # route requests to: - # /_matrix/client/.*/sync - # /_matrix/client/.*/user/{userId}/filter - # /_matrix/client/.*/user/{userId}/filter/{filterID} - # /_matrix/client/.*/keys/changes - # /_matrix/client/.*/rooms/{roomId}/messages - # /_matrix/client/.*/rooms/{roomId}/context/{eventID} - # /_matrix/client/.*/rooms/{roomId}/event/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType} - # /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType} - # /_matrix/client/.*/rooms/{roomId}/members - # /_matrix/client/.*/rooms/{roomId}/joined_members - # to sync_api - location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|.*?_?members|context/.*?|relations/.*?|event/.*?))$ { - proxy_pass http://sync_api:8073; - } - - location /_matrix/client { - proxy_pass http://client_api:8071; - } - - location /_matrix/federation { - proxy_pass http://federation_api:8072; - } - - location /_matrix/key { - proxy_pass http://federation_api:8072; - } - - location /_matrix/media { - proxy_pass http://media_api:8074; - } -} diff --git a/docs/systemd/monolith-example.service b/docs/systemd/monolith-example.service deleted file mode 100644 index 8a948a3faa..0000000000 --- a/docs/systemd/monolith-example.service +++ /dev/null @@ -1,19 +0,0 @@ -[Unit] -Description=Dendrite (Matrix Homeserver) -After=syslog.target -After=network.target -After=postgresql.service - -[Service] -Environment=GODEBUG=madvdontneed=1 -RestartSec=2s -Type=simple -User=dendrite -Group=dendrite -WorkingDirectory=/opt/dendrite/ -ExecStart=/opt/dendrite/bin/dendrite -Restart=always -LimitNOFILE=65535 - -[Install] -WantedBy=multi-user.target diff --git a/federationapi/api/api.go b/federationapi/api/api.go index e23bec2714..5b49e509e5 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -8,13 +8,16 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/federationapi/types" + rstypes "github.com/matrix-org/dendrite/roomserver/types" ) // FederationInternalAPI is used to query information from the federation sender. type FederationInternalAPI interface { gomatrixserverlib.FederatedStateClient + gomatrixserverlib.FederatedJoinClient KeyserverFederationAPI gomatrixserverlib.KeyDatabase ClientFederationAPI @@ -22,9 +25,9 @@ type FederationInternalAPI interface { P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error - LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res fclient.MSC2946SpacesResponse, err error) + LookupServerKeys(ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + MSC2836EventRelationships(ctx context.Context, origin, dst spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst spec.ServerName, roomID string, suggestedOnly bool) (res fclient.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -59,7 +62,7 @@ type RoomserverFederationAPI interface { // Handle an instruction to make_leave & send_leave with a remote server. PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error // Handle sending an invite to a remote server. - PerformInvite(ctx context.Context, request *PerformInviteRequest, response *PerformInviteResponse) error + SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) // Handle an instruction to peek a room on a remote server. PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error // Query the server names of the joined hosts in a room. @@ -67,9 +70,9 @@ type RoomserverFederationAPI interface { // containing only the server names (without information for membership events). // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error - GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) + GetEventAuth(ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) + GetEvent(ctx context.Context, origin, s spec.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) } type P2PFederationAPI interface { @@ -99,45 +102,9 @@ type P2PFederationAPI interface { // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type KeyserverFederationAPI interface { - GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res fclient.RespUserDevices, err error) - ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res fclient.RespClaimKeys, err error) - QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res fclient.RespQueryKeys, err error) -} - -// an interface for gmsl.FederationClient - contains functions called by federationapi only. -type FederationClient interface { - P2PFederationClient - gomatrixserverlib.KeyClient - SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res fclient.RespSend, err error) - - // Perform operations - LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res fclient.RespDirectory, err error) - Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespPeek, err error) - MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespMakeJoin, err error) - SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res fclient.RespSendJoin, err error) - MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res fclient.RespMakeLeave, err error) - SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) - SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res fclient.RespInviteV2, err error) - - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - - GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res fclient.RespEventAuth, err error) - GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (fclient.RespUserDevices, error) - ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (fclient.RespClaimKeys, error) - QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (fclient.RespQueryKeys, error) - Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) - MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res fclient.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res fclient.MSC2946SpacesResponse, err error) - - ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespState, err error) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res fclient.RespStateIDs, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res fclient.RespMissingEvents, err error) -} - -type P2PFederationClient interface { - P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res fclient.EmptyResp, err error) - P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev fclient.RelayEntry, relayServer gomatrixserverlib.ServerName) (res fclient.RespGetRelayTransaction, err error) + GetUserDevices(ctx context.Context, origin, s spec.ServerName, userID string) (res fclient.RespUserDevices, err error) + ClaimKeys(ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string) (res fclient.RespClaimKeys, err error) + QueryKeys(ctx context.Context, origin, s spec.ServerName, keys map[string][]string) (res fclient.RespQueryKeys, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. @@ -153,7 +120,7 @@ func (e FederationClientError) Error() string { } type QueryServerKeysRequest struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName KeyIDToCriteria map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria } @@ -172,7 +139,7 @@ type QueryServerKeysResponse struct { } type QueryPublicKeysRequest struct { - Requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp `json:"requests"` + Requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp `json:"requests"` } type QueryPublicKeysResponse struct { @@ -180,13 +147,13 @@ type QueryPublicKeysResponse struct { } type PerformDirectoryLookupRequest struct { - RoomAlias string `json:"room_alias"` - ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomAlias string `json:"room_alias"` + ServerName spec.ServerName `json:"server_name"` } type PerformDirectoryLookupResponse struct { - RoomID string `json:"room_id"` - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + RoomID string `json:"room_id"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformJoinRequest struct { @@ -199,7 +166,7 @@ type PerformJoinRequest struct { } type PerformJoinResponse struct { - JoinedVia gomatrixserverlib.ServerName + JoinedVia spec.ServerName LastError *gomatrix.HTTPError } @@ -223,13 +190,13 @@ type PerformLeaveResponse struct { } type PerformInviteRequest struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event *rstypes.HeaderedEvent `json:"event"` + InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"` } type PerformInviteResponse struct { - Event *gomatrixserverlib.HeaderedEvent `json:"event"` + Event *rstypes.HeaderedEvent `json:"event"` } // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames @@ -241,7 +208,7 @@ type QueryJoinedHostServerNamesInRoomRequest struct { // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomResponse struct { - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformBroadcastEDURequest struct { @@ -251,7 +218,7 @@ type PerformBroadcastEDUResponse struct { } type PerformWakeupServersRequest struct { - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` + ServerNames []spec.ServerName `json:"server_names"` } type PerformWakeupServersResponse struct { @@ -265,24 +232,24 @@ type InputPublicKeysResponse struct { } type P2PQueryRelayServersRequest struct { - Server gomatrixserverlib.ServerName + Server spec.ServerName } type P2PQueryRelayServersResponse struct { - RelayServers []gomatrixserverlib.ServerName + RelayServers []spec.ServerName } type P2PAddRelayServersRequest struct { - Server gomatrixserverlib.ServerName - RelayServers []gomatrixserverlib.ServerName + Server spec.ServerName + RelayServers []spec.ServerName } type P2PAddRelayServersResponse struct { } type P2PRemoveRelayServersRequest struct { - Server gomatrixserverlib.ServerName - RelayServers []gomatrixserverlib.ServerName + Server spec.ServerName + RelayServers []spec.ServerName } type P2PRemoveRelayServersResponse struct { diff --git a/federationapi/api/servers.go b/federationapi/api/servers.go deleted file mode 100644 index 6bb15763d5..0000000000 --- a/federationapi/api/servers.go +++ /dev/null @@ -1,11 +0,0 @@ -package api - -import ( - "context" - - "github.com/matrix-org/gomatrixserverlib" -) - -type ServersInRoomProvider interface { - GetServersForRoom(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName -} diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 7d9df3d78a..3fdc835bb1 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -20,6 +20,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -40,7 +41,7 @@ type KeyChangeConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool rsAPI roomserverAPI.FederationRoomserverAPI topic string } @@ -140,7 +141,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDeviceListUpdate, + Type: spec.MDeviceListUpdate, Origin: string(originServerName), } event := gomatrixserverlib.DeviceListUpdateEvent{ diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 29b16f3738..e751b65d4b 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) @@ -39,7 +40,7 @@ type OutputPresenceConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool rsAPI roomserverAPI.FederationRoomserverAPI topic string outboundPresenceEnabled bool @@ -127,7 +128,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg statusMsg = &status } - p := types.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(ts)} + p := types.PresenceInternal{LastActiveTS: spec.Timestamp(ts)} content := fedTypes.Presence{ Push: []fedTypes.PresenceContent{ @@ -142,7 +143,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MPresence, + Type: spec.MPresence, Origin: string(serverName), } if edu.Content, err = json.Marshal(content); err != nil { diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 200c06e6c9..1407a88b77 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" syncTypes "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) @@ -39,7 +40,7 @@ type OutputReceiptConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool topic string } @@ -107,7 +108,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) return true } - receipt.Timestamp = gomatrixserverlib.Timestamp(timestamp) + receipt.Timestamp = spec.Timestamp(timestamp) joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID) if err != nil { @@ -115,7 +116,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) return false } - names := make([]gomatrixserverlib.ServerName, len(joined)) + names := make([]spec.ServerName, len(joined)) for i := range joined { names[i] = joined[i].ServerName } @@ -133,7 +134,7 @@ func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) } edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MReceipt, + Type: spec.MReceipt, Origin: string(receiptServerName), } if edu.Content, err = json.Marshal(content); err != nil { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 378b96ba07..6dd2fd345a 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -22,6 +22,7 @@ import ( "time" syncAPITypes "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -186,7 +187,12 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew addsStateEvents = append(addsStateEvents, eventsRes.Events...) } - addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) + evs := make([]gomatrixserverlib.PDU, len(addsStateEvents)) + for i := range evs { + evs[i] = addsStateEvents[i].PDU + } + + addsJoinedHosts, err := JoinedHostsFromEvents(s.ctx, evs, s.rsAPI) if err != nil { return err } @@ -207,9 +213,9 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == spec.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() - if membership == gomatrixserverlib.Join { + if membership == spec.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) } } @@ -239,12 +245,12 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // Send the event. return s.queues.SendEvent( - ore.Event, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHostsAtEvent, + ore.Event, spec.ServerName(ore.SendAsServer), joinedHostsAtEvent, ) } func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { - joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined)) + joined := make([]spec.ServerName, 0, len(addedJoined)) for _, added := range addedJoined { joined = append(joined, added.ServerName) } @@ -285,7 +291,7 @@ func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []type continue } - p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + p := syncAPITypes.PresenceInternal{LastActiveTS: spec.Timestamp(lastActive)} content.Push = append(content.Push, types.PresenceContent{ CurrentlyActive: p.CurrentlyActive(), @@ -301,7 +307,7 @@ func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []type } edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MPresence, + Type: spec.MPresence, Origin: string(s.cfg.Matrix.ServerName), } if edu.Content, err = json.Marshal(content); err != nil { @@ -326,7 +332,7 @@ func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []type // Returns an error if there was a problem talking to the room server. func (s *OutputRoomEventConsumer) joinedHostsAtEvent( ore api.OutputNewRoomEvent, oldJoinedHosts []types.JoinedHost, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { // Combine the delta into a single delta so that the adds and removes can // cancel each other out. This should reduce the number of times we need // to fetch a state event from the room server. @@ -334,12 +340,12 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( ore.AddsStateEventIDs, ore.RemovesStateEventIDs, ore.StateBeforeAddsEventIDs, ore.StateBeforeRemovesEventIDs, ) - combinedAddsEvents, err := s.lookupStateEvents(combinedAdds, ore.Event.Event) + combinedAddsEvents, err := s.lookupStateEvents(combinedAdds, ore.Event.PDU) if err != nil { return nil, err } - combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents) + combinedAddsJoinedHosts, err := JoinedHostsFromEvents(s.ctx, combinedAddsEvents, s.rsAPI) if err != nil { return nil, err } @@ -349,7 +355,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( removed[eventID] = true } - joined := map[gomatrixserverlib.ServerName]bool{} + joined := map[spec.ServerName]bool{} for _, joinedHost := range oldJoinedHosts { if removed[joinedHost.MemberEventID] { // This m.room.member event is part of the current state of the @@ -368,7 +374,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( } // handle peeking hosts - inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.Event.RoomID()) + inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.PDU.RoomID()) if err != nil { return nil, err } @@ -376,7 +382,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( joined[inboundPeek.ServerName] = true } - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for serverName, include := range joined { if include { result = append(result, serverName) @@ -388,7 +394,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( // JoinedHostsFromEvents turns a list of state events into a list of joined hosts. // This errors if one of the events was invalid. // It should be impossible for an invalid event to get this far in the pipeline. -func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { +func JoinedHostsFromEvents(ctx context.Context, evs []gomatrixserverlib.PDU, rsAPI api.FederationRoomserverAPI) ([]types.JoinedHost, error) { var joinedHosts []types.JoinedHost for _, ev := range evs { if ev.Type() != "m.room.member" || ev.StateKey() == nil { @@ -398,15 +404,20 @@ func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, if err != nil { return nil, err } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } - _, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return nil, err + } + userID, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) if err != nil { return nil, err } + joinedHosts = append(joinedHosts, types.JoinedHost{ - MemberEventID: ev.EventID(), ServerName: serverName, + MemberEventID: ev.EventID(), ServerName: userID.Domain(), }) } return joinedHosts, nil @@ -453,8 +464,8 @@ func combineDeltas(adds1, removes1, adds2, removes2 []string) (adds, removes []s // lookupStateEvents looks up the state events that are added by a new event. func (s *OutputRoomEventConsumer) lookupStateEvents( - addsStateEventIDs []string, event *gomatrixserverlib.Event, -) ([]*gomatrixserverlib.Event, error) { + addsStateEventIDs []string, event gomatrixserverlib.PDU, +) ([]gomatrixserverlib.PDU, error) { // Fast path if there aren't any new state events. if len(addsStateEventIDs) == 0 { return nil, nil @@ -462,11 +473,11 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // Fast path if the only state event added is the event itself. if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []*gomatrixserverlib.Event{event}, nil + return []gomatrixserverlib.PDU{event}, nil } missing := addsStateEventIDs - var result []*gomatrixserverlib.Event + var result []gomatrixserverlib.PDU // Check if event itself is being added. for _, eventID := range missing { @@ -491,7 +502,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( } for _, headeredEvent := range eventResp.Events { - result = append(result, headeredEvent.Event) + result = append(result, headeredEvent.PDU) } missing = missingEventsFrom(result, addsStateEventIDs) @@ -505,7 +516,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( return result, nil } -func missingEventsFrom(events []*gomatrixserverlib.Event, required []string) []string { +func missingEventsFrom(events []gomatrixserverlib.PDU, required []string) []string { have := map[string]bool{} for _, event := range events { have[event.EventID()] = true diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index 9620d16120..91b28cdbfb 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -20,6 +20,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -39,7 +40,7 @@ type OutputSendToDeviceConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool topic string } @@ -107,7 +108,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDirectToDevice, + Type: spec.MDirectToDevice, Origin: string(originServerName), } tdm := gomatrixserverlib.ToDeviceMessage{ @@ -127,7 +128,7 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats } log.Debugf("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, originServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + if err := t.queues.SendEDU(edu, originServerName, []spec.ServerName{destServerName}); err != nil { log.WithError(err).Error("failed to send EDU") return false } diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index c66f97519f..134f2174f3 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" ) @@ -36,7 +37,7 @@ type OutputTypingConsumer struct { durable string db storage.Database queues *queue.OutgoingQueues - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool topic string } @@ -97,7 +98,7 @@ func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) return false } - names := make([]gomatrixserverlib.ServerName, len(joined)) + names := make([]spec.ServerName, len(joined)) for i := range joined { names[i] = joined[i].ServerName } diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index c64fa550de..ee15a8a6e0 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -49,11 +49,10 @@ func AddPublicRoutes( dendriteConfig *config.Dendrite, natsInstance *jetstream.NATSInstance, userAPI userapi.FederationUserAPI, - federation *fclient.FederationClient, + federation fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.FederationRoomserverAPI, fedAPI federationAPI.FederationInternalAPI, - servers federationAPI.ServersInRoomProvider, enableMetrics bool, ) { cfg := &dendriteConfig.FederationAPI @@ -87,7 +86,7 @@ func AddPublicRoutes( dendriteConfig, rsAPI, f, keyRing, federation, userAPI, mscCfg, - servers, producer, enableMetrics, + producer, enableMetrics, ) } @@ -98,7 +97,7 @@ func NewInternalAPI( dendriteCfg *config.Dendrite, cm sqlutil.Connections, natsInstance *jetstream.NATSInstance, - federation api.FederationClient, + federation fclient.FederationClient, rsAPI roomserverAPI.FederationRoomserverAPI, caches *caching.Caches, keyRing *gomatrixserverlib.KeyRing, diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 2fa748bade..9dda389ed9 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -17,6 +17,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/routing" @@ -25,12 +26,12 @@ import ( ) type server struct { - name gomatrixserverlib.ServerName // server name - validity time.Duration // key validity duration from now - config *config.FederationAPI // skeleton config, from TestMain - fedclient *fclient.FederationClient // uses MockRoundTripper - cache *caching.Caches // server-specific cache - api api.FederationInternalAPI // server-specific server key API + name spec.ServerName // server name + validity time.Duration // key validity duration from now + config *config.FederationAPI // skeleton config, from TestMain + fedclient fclient.FederationClient // uses MockRoundTripper + cache *caching.Caches // server-specific cache + api api.FederationInternalAPI // server-specific server key API } func (s *server) renew() { @@ -83,7 +84,7 @@ func TestMain(m *testing.M) { Generate: true, SingleDatabase: false, }) - cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name) + cfg.Global.ServerName = spec.ServerName(s.name) cfg.Global.PrivateKey = testPriv cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.TopicPrefix = string(s.name[:1]) @@ -141,7 +142,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err } // Get the keys and JSON-ify them. - keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host)) + keys := routing.LocalKeys(s.config, spec.ServerName(req.Host)) body, err := json.MarshalIndent(keys.JSON, "", " ") if err != nil { return nil, err @@ -166,8 +167,8 @@ func TestServersRequestOwnKeys(t *testing.T) { } res, err := s.api.FetchKeys( context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ - req: gomatrixserverlib.AsTimestamp(time.Now()), + map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{ + req: spec.AsTimestamp(time.Now()), }, ) if err != nil { @@ -192,8 +193,8 @@ func TestRenewalBehaviour(t *testing.T) { res, err := serverA.api.FetchKeys( context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ - req: gomatrixserverlib.AsTimestamp(time.Now()), + map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{ + req: spec.AsTimestamp(time.Now()), }, ) if err != nil { @@ -216,8 +217,8 @@ func TestRenewalBehaviour(t *testing.T) { res, err = serverA.api.FetchKeys( context.Background(), - map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{ - req: gomatrixserverlib.AsTimestamp(time.Now()), + map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{ + req: spec.AsTimestamp(time.Now()), }, ) if err != nil { diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 3c01a82596..5d167c0eea 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -16,12 +16,14 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -34,13 +36,20 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil +} + // PerformJoin will call this function -func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error { +func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { if f.inputRoomEvents == nil { - return nil + return } f.inputRoomEvents(ctx, req, res) - return nil } // keychange consumer calls this @@ -54,9 +63,9 @@ func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.Que // TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate type fedClient struct { fedClientMutex sync.Mutex - api.FederationClient + fclient.FederationClient allowJoins []*test.Room - keys map[gomatrixserverlib.ServerName]struct { + keys map[spec.ServerName]struct { key ed25519.PrivateKey keyID gomatrixserverlib.KeyID } @@ -64,7 +73,7 @@ type fedClient struct { sentTxn bool } -func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) { +func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer spec.ServerName) (gomatrixserverlib.ServerKeys, error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() fmt.Println("GetServerKeys:", matrixServer) @@ -83,11 +92,11 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv } keys.ServerName = matrixServer - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour)) + keys.ValidUntilTS = spec.AsTimestamp(time.Now().Add(10 * time.Hour)) publicKey := pkey.Public().(ed25519.PublicKey) keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ keyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), + Key: spec.Base64Bytes(publicKey), }, } toSign, err := json.Marshal(keys.ServerKeyFields) @@ -105,20 +114,23 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv return keys, nil } -func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res fclient.RespMakeJoin, err error) { +func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res fclient.RespMakeJoin, err error) { + f.fedClientMutex.Lock() + defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { if r.ID == roomID { + senderIDString := userID res.RoomVersion = r.Version - res.JoinEvent = gomatrixserverlib.EventBuilder{ - Sender: userID, + res.JoinEvent = gomatrixserverlib.ProtoEvent{ + SenderID: senderIDString, RoomID: roomID, Type: "m.room.member", - StateKey: &userID, - Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)), + StateKey: &senderIDString, + Content: spec.RawJSON([]byte(`{"membership":"join"}`)), PrevEvents: r.ForwardExtremities(), } var needed gomatrixserverlib.StateNeeded - needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent) + needed, err = gomatrixserverlib.StateNeededForProtoEvent(&res.JoinEvent) if err != nil { f.t.Errorf("StateNeededForEventBuilder: %v", err) return @@ -129,15 +141,15 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.Se } return } -func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res fclient.RespSendJoin, err error) { +func (f *fedClient) SendJoin(ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU) (res fclient.RespSendJoin, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { if r.ID == event.RoomID() { - r.InsertEvent(f.t, event.Headered(r.Version)) + r.InsertEvent(f.t, &types.HeaderedEvent{PDU: event}) f.t.Logf("Join event: %v", event.EventID()) - res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState()) - res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events()) + res.StateEvents = types.NewEventJSONsFromHeaderedEvents(r.CurrentState()) + res.AuthEvents = types.NewEventJSONsFromHeaderedEvents(r.Events()) } } return @@ -147,7 +159,7 @@ func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Tra f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, edu := range t.EDUs { - if edu.Type == gomatrixserverlib.MDeviceListUpdate { + if edu.Type == spec.MDeviceListUpdate { f.sentTxn = true } } @@ -174,7 +186,7 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) - serverA := gomatrixserverlib.ServerName("server.a") + serverA := spec.ServerName("server.a") serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera") serverAPrivKey := test.PrivateKeyA creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey)) @@ -203,7 +215,7 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { fc := &fedClient{ allowJoins: []*test.Room{room}, t: t, - keys: map[gomatrixserverlib.ServerName]struct { + keys: map[spec.ServerName]struct { key ed25519.PrivateKey keyID gomatrixserverlib.KeyID }{ @@ -223,7 +235,7 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{ RoomID: room.ID, UserID: joiningUser.ID, - ServerNames: []gomatrixserverlib.ServerName{serverA}, + ServerNames: []spec.ServerName{serverA}, }, &resp) if resp.JoinedVia != serverA { t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA) @@ -302,17 +314,17 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { _, privKey, _ := ed25519.GenerateKey(nil) cfg.Global.KeyID = gomatrixserverlib.KeyID("ed25519:auto") - cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") + cfg.Global.ServerName = spec.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true keyRing := &test.NopJSONVerifier{} natsInstance := jetstream.NATSInstance{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, caching.DisableMetrics) + federationapi.AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, caching.DisableMetrics) baseURL, cancel := test.ListenAndServe(t, routers.Federation, true) defer cancel() - serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) + serverName := spec.ServerName(strings.TrimPrefix(baseURL, "https://")) fedCli := fclient.NewFederationClient( cfg.Global.SigningIdentities(), @@ -320,12 +332,11 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { ) for _, tc := range testCases { - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(tc.eventJSON), false, tc.roomVer) + ev, err := gomatrixserverlib.MustGetRoomVersion(tc.roomVer).NewEventFromTrustedJSON([]byte(tc.eventJSON), false) if err != nil { t.Errorf("failed to parse event: %s", err) } - he := ev.Headered(tc.roomVer) - invReq, err := gomatrixserverlib.NewInviteV2Request(he, nil) + invReq, err := fclient.NewInviteV2Request(ev, nil) if err != nil { t.Errorf("failed to create invite v2 request: %s", err) continue diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 99773a7508..aa501f63c0 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -17,6 +17,8 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -26,7 +28,7 @@ type FederationInternalAPI struct { cfg *config.FederationAPI statistics *statistics.Statistics rsAPI roomserverAPI.FederationRoomserverAPI - federation api.FederationClient + federation fclient.FederationClient keyRing *gomatrixserverlib.KeyRing queues *queue.OutgoingQueues joins sync.Map // joins currently in progress @@ -35,7 +37,7 @@ type FederationInternalAPI struct { func NewFederationInternalAPI( db storage.Database, cfg *config.FederationAPI, rsAPI roomserverAPI.FederationRoomserverAPI, - federation api.FederationClient, + federation fclient.FederationClient, statistics *statistics.Statistics, caches *caching.Caches, queues *queue.OutgoingQueues, @@ -107,7 +109,7 @@ func NewFederationInternalAPI( } } -func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { +func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s spec.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) if stats.Blacklisted() { return stats, &api.FederationClientError{ @@ -144,7 +146,7 @@ func failBlacklistableError(err error, stats *statistics.ServerStatistics) (unti } func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( - s gomatrixserverlib.ServerName, request func() (interface{}, error), + s spec.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats, err := a.isBlacklistedOrBackingOff(s) if err != nil { @@ -169,7 +171,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( } func (a *FederationInternalAPI) doRequestIfNotBlacklisted( - s gomatrixserverlib.ServerName, request func() (interface{}, error), + s spec.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) if blacklisted := stats.Blacklisted(); blacklisted { diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b0d5b1d1f2..d4d7269dbb 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -6,16 +6,43 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) +const defaultTimeout = time.Second * 30 + // Functions here are "proxying" calls to the gomatrixserverlib federation // client. +func (a *FederationInternalAPI) MakeJoin( + ctx context.Context, origin, s spec.ServerName, roomID, userID string, +) (res gomatrixserverlib.MakeJoinResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + ires, err := a.federation.MakeJoin(ctx, origin, s, roomID, userID) + if err != nil { + return &fclient.RespMakeJoin{}, err + } + return &ires, nil +} + +func (a *FederationInternalAPI) SendJoin( + ctx context.Context, origin, s spec.ServerName, event gomatrixserverlib.PDU, +) (res gomatrixserverlib.SendJoinResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + ires, err := a.federation.SendJoin(ctx, origin, s, event) + if err != nil { + return &fclient.RespSendJoin{}, err + } + return &ires, nil +} + func (a *FederationInternalAPI) GetEventAuth( - ctx context.Context, origin, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res fclient.RespEventAuth, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) @@ -27,9 +54,9 @@ func (a *FederationInternalAPI) GetEventAuth( } func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s spec.ServerName, userID string, ) (fclient.RespUserDevices, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetUserDevices(ctx, origin, s, userID) @@ -41,9 +68,9 @@ func (a *FederationInternalAPI) GetUserDevices( } func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string, ) (fclient.RespClaimKeys, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) @@ -55,7 +82,7 @@ func (a *FederationInternalAPI) ClaimKeys( } func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s spec.ServerName, keys map[string][]string, ) (fclient.RespQueryKeys, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { return a.federation.QueryKeys(ctx, origin, s, keys) @@ -67,9 +94,9 @@ func (a *FederationInternalAPI) QueryKeys( } func (a *FederationInternalAPI) Backfill( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s spec.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) @@ -81,9 +108,9 @@ func (a *FederationInternalAPI) Backfill( } func (a *FederationInternalAPI) LookupState( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s spec.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.StateResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) @@ -96,9 +123,9 @@ func (a *FederationInternalAPI) LookupState( } func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s spec.ServerName, roomID, eventID string, ) (res gomatrixserverlib.StateIDResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) @@ -110,10 +137,10 @@ func (a *FederationInternalAPI) LookupStateIDs( } func (a *FederationInternalAPI) LookupMissingEvents( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s spec.ServerName, roomID string, missing fclient.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.RespMissingEvents, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) @@ -125,9 +152,9 @@ func (a *FederationInternalAPI) LookupMissingEvents( } func (a *FederationInternalAPI) GetEvent( - ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s spec.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.GetEvent(ctx, origin, s, eventID) @@ -139,7 +166,7 @@ func (a *FederationInternalAPI) GetEvent( } func (a *FederationInternalAPI) LookupServerKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + ctx context.Context, s spec.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() @@ -153,7 +180,7 @@ func (a *FederationInternalAPI) LookupServerKeys( } func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, origin, s gomatrixserverlib.ServerName, r fclient.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s spec.ServerName, r fclient.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res fclient.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) @@ -168,7 +195,7 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, s spec.ServerName, roomID string, suggestedOnly bool, ) (res fclient.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go index 948a96eec1..8c562dd61e 100644 --- a/federationapi/internal/federationclient_test.go +++ b/federationapi/internal/federationclient_test.go @@ -24,8 +24,8 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -34,7 +34,7 @@ const ( FailuresUntilBlacklist = 8 ) -func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (fclient.RespQueryKeys, error) { +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s spec.ServerName, keys map[string][]string) (fclient.RespQueryKeys, error) { t.queryKeysCalled = true if t.shouldFail { return fclient.RespQueryKeys{}, fmt.Errorf("Failure") @@ -42,7 +42,7 @@ func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverl return fclient.RespQueryKeys{}, nil } -func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (fclient.RespClaimKeys, error) { +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s spec.ServerName, oneTimeKeys map[string]map[string]string) (fclient.RespClaimKeys, error) { t.claimKeysCalled = true if t.shouldFail { return fclient.RespClaimKeys{}, fmt.Errorf("Failure") diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 258bd88bf8..a642f3a4bd 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -31,14 +32,14 @@ func (s *FederationInternalAPI) StoreKeys( func (s *FederationInternalAPI) FetchKeys( _ context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { // Run in a background context - we don't want to stop this work just // because the caller gives up waiting. ctx := context.Background() - now := gomatrixserverlib.AsTimestamp(time.Now()) + now := spec.AsTimestamp(time.Now()) results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{} + origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{} for k, v := range requests { origRequests[k] = v } @@ -95,7 +96,7 @@ func (s *FederationInternalAPI) FetcherName() string { // a request for our own server keys, either current or old. func (s *FederationInternalAPI) handleLocalKeys( _ context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) { for req := range requests { @@ -111,10 +112,10 @@ func (s *FederationInternalAPI) handleLocalKeys( // Insert our own key into the response. results[req] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64Bytes(s.cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)), + Key: spec.Base64Bytes(s.cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)), }, ExpiredTS: gomatrixserverlib.PublicKeyNotExpired, - ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.cfg.Matrix.KeyValidityPeriod)), + ValidUntilTS: spec.AsTimestamp(time.Now().Add(s.cfg.Matrix.KeyValidityPeriod)), } } else { // The key request doesn't match our current key. Let's see @@ -128,7 +129,7 @@ func (s *FederationInternalAPI) handleLocalKeys( // Insert our own key into the response. results[req] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: gomatrixserverlib.VerifyKey{ - Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)), + Key: spec.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)), }, ExpiredTS: oldVerifyKey.ExpiredAt, ValidUntilTS: gomatrixserverlib.PublicKeyNotValid, @@ -146,8 +147,8 @@ func (s *FederationInternalAPI) handleLocalKeys( // satisfied from our local database/cache. func (s *FederationInternalAPI) handleDatabaseKeys( ctx context.Context, - now gomatrixserverlib.Timestamp, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + now spec.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) error { // Ask the database/cache for the keys. @@ -169,7 +170,7 @@ func (s *FederationInternalAPI) handleDatabaseKeys( // in that case. If the key isn't valid right now, then by // leaving it in the 'requests' map, we'll try to update the // key using the fetchers in handleFetcherKeys. - if res.WasValidAt(now, true) { + if res.WasValidAt(now, gomatrixserverlib.StrictValiditySignatureCheck) { delete(requests, req) } } @@ -180,9 +181,9 @@ func (s *FederationInternalAPI) handleDatabaseKeys( // the remaining requests. func (s *FederationInternalAPI) handleFetcherKeys( ctx context.Context, - _ gomatrixserverlib.Timestamp, + _ spec.Timestamp, fetcher gomatrixserverlib.KeyFetcher, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) error { logrus.WithFields(logrus.Fields{ diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 08287c6921..515b3377de 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/ed25519" "encoding/json" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -17,6 +19,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/consumers" "github.com/matrix-org/dendrite/federationapi/statistics" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -72,17 +75,11 @@ func (r *FederationInternalAPI) PerformJoin( r.joins.Store(j, nil) defer r.joins.Delete(j) - // Look up the supported room versions. - var supportedVersions []gomatrixserverlib.RoomVersion - for version := range version.SupportedRoomVersions() { - supportedVersions = append(supportedVersions, version) - } - // Deduplicate the server names we were provided but keep the ordering // as this encodes useful information about which servers are most likely // to respond. - seenSet := make(map[gomatrixserverlib.ServerName]bool) - var uniqueList []gomatrixserverlib.ServerName + seenSet := make(map[spec.ServerName]bool) + var uniqueList []spec.ServerName for _, srv := range request.ServerNames { if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) { continue @@ -102,7 +99,6 @@ func (r *FederationInternalAPI) PerformJoin( request.UserID, request.Content, serverName, - supportedVersions, request.Unsigned, ); err != nil { logrus.WithError(err).WithFields(logrus.Fields{ @@ -144,128 +140,70 @@ func (r *FederationInternalAPI) performJoinUsingServer( ctx context.Context, roomID, userID string, content map[string]interface{}, - serverName gomatrixserverlib.ServerName, - supportedVersions []gomatrixserverlib.RoomVersion, + serverName spec.ServerName, unsigned map[string]interface{}, ) error { if !r.shouldAttemptDirectFederation(serverName) { return fmt.Errorf("relay servers have no meaningful response for join.") } - _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) + user, err := spec.NewUserID(userID, true) if err != nil { return err } - - // Try to perform a make_join using the information supplied in the - // request. - respMakeJoin, err := r.federation.MakeJoin( - ctx, - origin, - serverName, - roomID, - userID, - supportedVersions, - ) - if err != nil { - // TODO: Check if the user was not allowed to join the room. - r.statistics.ForServer(serverName).Failure() - return fmt.Errorf("r.federation.MakeJoin: %w", err) - } - r.statistics.ForServer(serverName).Success(statistics.SendDirect) - - // Set all the fields to be what they should be, this should be a no-op - // but it's possible that the remote server returned us something "odd" - respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember - respMakeJoin.JoinEvent.Sender = userID - respMakeJoin.JoinEvent.StateKey = &userID - respMakeJoin.JoinEvent.RoomID = roomID - respMakeJoin.JoinEvent.Redacts = "" - if content == nil { - content = map[string]interface{}{} - } - _ = json.Unmarshal(respMakeJoin.JoinEvent.Content, &content) - content["membership"] = gomatrixserverlib.Join - if err = respMakeJoin.JoinEvent.SetContent(content); err != nil { - return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err) - } - if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil { - return fmt.Errorf("respMakeJoin.JoinEvent.SetUnsigned: %w", err) - } - - // Work out if we support the room version that has been supplied in - // the make_join response. - // "If not provided, the room version is assumed to be either "1" or "2"." - // https://matrix.org/docs/spec/server_server/unstable#get-matrix-federation-v1-make-join-roomid-userid - if respMakeJoin.RoomVersion == "" { - respMakeJoin.RoomVersion = setDefaultRoomVersionFromJoinEvent(respMakeJoin.JoinEvent) - } - if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil { - return fmt.Errorf("respMakeJoin.RoomVersion.EventFormat: %w", err) - } - - // Build the join event. - event, err := respMakeJoin.JoinEvent.Build( - time.Now(), - origin, - r.cfg.Matrix.KeyID, - r.cfg.Matrix.PrivateKey, - respMakeJoin.RoomVersion, - ) + room, err := spec.NewRoomID(roomID) if err != nil { - return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) + return err } - // Try to perform a send_join using the newly built event. - respSendJoin, err := r.federation.SendJoin( - context.Background(), - origin, - serverName, - event, - ) - if err != nil { - r.statistics.ForServer(serverName).Failure() - return fmt.Errorf("r.federation.SendJoin: %w", err) + joinInput := gomatrixserverlib.PerformJoinInput{ + UserID: user, + RoomID: room, + ServerName: serverName, + Content: content, + Unsigned: unsigned, + PrivateKey: r.cfg.Matrix.PrivateKey, + KeyID: r.cfg.Matrix.KeyID, + KeyRing: r.keyRing, + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, + GetOrCreateSenderID: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) { + // assign a roomNID, otherwise we can't create a private key for the user + _, nidErr := r.rsAPI.AssignRoomNID(ctx, roomID, gomatrixserverlib.RoomVersion(roomVersion)) + if nidErr != nil { + return "", nil, nidErr + } + key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if keyErr != nil { + return "", nil, keyErr + } + return spec.SenderIDFromPseudoIDKey(key), key, nil + }, + StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error { + storeUserID, userErr := spec.NewUserID(userIDRaw, true) + if userErr != nil { + return userErr + } + return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) + }, } - r.statistics.ForServer(serverName).Success(statistics.SendDirect) + response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) - // If the remote server returned an event in the "event" key of - // the send_join request then we should use that instead. It may - // contain signatures that we don't know about. - if len(respSendJoin.Event) > 0 { - var remoteEvent *gomatrixserverlib.Event - remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion) - if err == nil && isWellFormedMembershipEvent( - remoteEvent, roomID, userID, - ) { - event = remoteEvent + if joinErr != nil { + if !joinErr.Reachable { + r.statistics.ForServer(joinErr.ServerName).Failure() + } else { + r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect) } + return joinErr.Err } - - // Sanity-check the join response to ensure that it has a create - // event, that the room version is known, etc. - authEvents := respSendJoin.AuthEvents.UntrustedEvents(respMakeJoin.RoomVersion) - if err = sanityCheckAuthChain(authEvents); err != nil { - return fmt.Errorf("sanityCheckAuthChain: %w", err) - } - - // Process the join response in a goroutine. The idea here is - // that we'll try and wait for as long as possible for the work - // to complete, but if the client does give up waiting, we'll - // still continue to process the join anyway so that we don't - // waste the effort. - // TODO: Can we expand Check here to return a list of missing auth - // events rather than failing one at a time? - var respState *fclient.RespState - respState, err = respSendJoin.Check( - context.Background(), - respMakeJoin.RoomVersion, - r.keyRing, - event, - federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), - ) - if err != nil { - return fmt.Errorf("respSendJoin.Check: %w", err) + r.statistics.ForServer(serverName).Success(statistics.SendDirect) + if response == nil { + return fmt.Errorf("Received nil response from gomatrixserverlib.PerformJoin") } // We need to immediately update our list of joined hosts for this room now as we are technically @@ -274,60 +212,33 @@ func (r *FederationInternalAPI) performJoinUsingServer( // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") // The events are trusted now as we performed auth checks above. - joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false)) + joinedHosts, err := consumers.JoinedHostsFromEvents(ctx, response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false), r.rsAPI) if err != nil { return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) } + logrus.WithField("room", roomID).Infof("Joined federated room with %d hosts", len(joinedHosts)) if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil { return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err) } - // If we successfully performed a send_join above then the other - // server now thinks we're a part of the room. Send the newly - // returned state to the roomserver to update our local view. - if unsigned != nil { - event, err = event.SetUnsigned(unsigned) - if err != nil { - // non-fatal, log and continue - logrus.WithError(err).Errorf("Failed to set unsigned content") - } - } - + // TODO: Can I change this to not take respState but instead just take an opaque list of events? if err = roomserverAPI.SendEventWithState( context.Background(), r.rsAPI, - origin, + user.Domain(), roomserverAPI.KindNew, - respState, - event.Headered(respMakeJoin.RoomVersion), + response.StateSnapshot, + &types.HeaderedEvent{PDU: response.JoinEvent}, serverName, nil, false, ); err != nil { return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err) } - return nil } -// isWellFormedMembershipEvent returns true if the event looks like a legitimate -// membership event. -func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string) bool { - if membership, err := event.Membership(); err != nil { - return false - } else if membership != gomatrixserverlib.Join { - return false - } - if event.RoomID() != roomID { - return false - } - if !event.StateKeyEquals(userID) { - return false - } - return true -} - // PerformOutboundPeekRequest implements api.FederationInternalAPI func (r *FederationInternalAPI) PerformOutboundPeek( ctx context.Context, @@ -343,8 +254,8 @@ func (r *FederationInternalAPI) PerformOutboundPeek( // Deduplicate the server names we were provided but keep the ordering // as this encodes useful information about which servers are most likely // to respond. - seenSet := make(map[gomatrixserverlib.ServerName]bool) - var uniqueList []gomatrixserverlib.ServerName + seenSet := make(map[spec.ServerName]bool) + var uniqueList []spec.ServerName for _, srv := range request.ServerNames { if seenSet[srv] { continue @@ -410,7 +321,7 @@ func (r *FederationInternalAPI) PerformOutboundPeek( func (r *FederationInternalAPI) performOutboundPeekUsingServer( ctx context.Context, roomID string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { if !r.shouldAttemptDirectFederation(serverName) { @@ -463,21 +374,25 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( if respPeek.RoomVersion == "" { respPeek.RoomVersion = gomatrixserverlib.RoomVersionV1 } - if _, err = respPeek.RoomVersion.EventFormat(); err != nil { - return fmt.Errorf("respPeek.RoomVersion.EventFormat: %w", err) + if !gomatrixserverlib.KnownRoomVersion(respPeek.RoomVersion) { + return fmt.Errorf("unknown room version: %s", respPeek.RoomVersion) } // we have the peek state now so let's process regardless of whether upstream gives up ctx = context.Background() - respState := respPeek.ToRespState() // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) + userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + } + authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( + ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, + ) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } - if err = sanityCheckAuthChain(authEvents); err != nil { + if err = checkEventsContainCreateEvent(authEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } @@ -497,8 +412,12 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( if err = roomserverAPI.SendEventWithState( ctx, r.rsAPI, r.cfg.Matrix.ServerName, roomserverAPI.KindNew, - &respState, - respPeek.LatestEvent.Headered(respPeek.RoomVersion), + // use the authorized state from CheckStateResponse + &fclient.RespState{ + StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents), + }, + &types.HeaderedEvent{PDU: respPeek.LatestEvent}, serverName, nil, false, @@ -515,7 +434,7 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + userID, err := spec.NewUserID(request.UserID, true) if err != nil { return err } @@ -534,7 +453,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, - origin, + userID.Domain(), serverName, request.RoomID, request.UserID, @@ -546,40 +465,51 @@ func (r *FederationInternalAPI) PerformLeave( continue } + // Work out if we support the room version that has been supplied in + // the make_leave response. + verImpl, err := gomatrixserverlib.GetRoomVersion(respMakeLeave.RoomVersion) + if err != nil { + return err + } + // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" - respMakeLeave.LeaveEvent.Type = gomatrixserverlib.MRoomMember - respMakeLeave.LeaveEvent.Sender = request.UserID - respMakeLeave.LeaveEvent.StateKey = &request.UserID + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) + if err != nil { + return err + } + senderIDString := string(senderID) + respMakeLeave.LeaveEvent.Type = spec.MRoomMember + respMakeLeave.LeaveEvent.SenderID = senderIDString + respMakeLeave.LeaveEvent.StateKey = &senderIDString respMakeLeave.LeaveEvent.RoomID = request.RoomID respMakeLeave.LeaveEvent.Redacts = "" + leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent) + if respMakeLeave.LeaveEvent.Content == nil { content := map[string]interface{}{ "membership": "leave", } - if err = respMakeLeave.LeaveEvent.SetContent(content); err != nil { + if err = leaveEB.SetContent(content); err != nil { logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetContent failed") continue } } - if err = respMakeLeave.LeaveEvent.SetUnsigned(struct{}{}); err != nil { + if err = leaveEB.SetUnsigned(struct{}{}); err != nil { logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetUnsigned failed") continue } - // Work out if we support the room version that has been supplied in - // the make_leave response. - if _, err = respMakeLeave.RoomVersion.EventFormat(); err != nil { - return gomatrixserverlib.UnsupportedRoomVersionError{} - } - // Build the leave event. - event, err := respMakeLeave.LeaveEvent.Build( + event, err := leaveEB.Build( time.Now(), - origin, + userID.Domain(), r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, - respMakeLeave.RoomVersion, ) if err != nil { logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.Build failed") @@ -589,7 +519,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, - origin, + userID.Domain(), serverName, event, ) @@ -610,56 +540,63 @@ func (r *FederationInternalAPI) PerformLeave( ) } -// PerformLeaveRequest implements api.FederationInternalAPI -func (r *FederationInternalAPI) PerformInvite( +// SendInvite implements api.FederationInternalAPI +func (r *FederationInternalAPI) SendInvite( ctx context.Context, - request *api.PerformInviteRequest, - response *api.PerformInviteResponse, -) (err error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender()) + event gomatrixserverlib.PDU, + strippedState []gomatrixserverlib.InviteStrippedState, +) (gomatrixserverlib.PDU, error) { + validRoomID, err := spec.NewRoomID(event.RoomID()) if err != nil { - return err + return nil, err + } + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err != nil { + return nil, err } - if request.Event.StateKey() == nil { - return errors.New("invite must be a state event") + if event.StateKey() == nil { + return nil, errors.New("invite must be a state event") } - _, destination, err := gomatrixserverlib.SplitID('@', *request.Event.StateKey()) + _, destination, err := gomatrixserverlib.SplitID('@', *event.StateKey()) if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } // TODO (devon): This should be allowed via a relay. Currently only transactions // can be sent to relays. Would need to extend relays to handle invites. if !r.shouldAttemptDirectFederation(destination) { - return fmt.Errorf("relay servers have no meaningful response for invite.") + return nil, fmt.Errorf("relay servers have no meaningful response for invite.") } logrus.WithFields(logrus.Fields{ - "event_id": request.Event.EventID(), - "user_id": *request.Event.StateKey(), - "room_id": request.Event.RoomID(), - "room_version": request.RoomVersion, + "event_id": event.EventID(), + "user_id": *event.StateKey(), + "room_id": event.RoomID(), + "room_version": event.Version(), "destination": destination, }).Info("Sending invite") - inviteReq, err := gomatrixserverlib.NewInviteV2Request(request.Event, request.InviteRoomState) + inviteReq, err := fclient.NewInviteV2Request(event, strippedState) if err != nil { - return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, inviter.Domain(), destination, inviteReq) if err != nil { - return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) + return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) + } + verImpl, err := gomatrixserverlib.GetRoomVersion(event.Version()) + if err != nil { + return nil, err } - inviteEvent, err := inviteRes.Event.UntrustedEvent(request.RoomVersion) + inviteEvent, err := verImpl.NewEventFromUntrustedJSON(inviteRes.Event) if err != nil { - return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) + return nil, fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) } - response.Event = inviteEvent.Headered(request.RoomVersion) - return nil + return inviteEvent, nil } // PerformServersAlive implements api.FederationInternalAPI @@ -700,17 +637,17 @@ func (r *FederationInternalAPI) PerformWakeupServers( return nil } -func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { +func (r *FederationInternalAPI) MarkServersAlive(destinations []spec.ServerName) { for _, srv := range destinations { wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() r.queues.RetryServer(srv, wasBlacklisted) } } -func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { +func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error { // sanity check we have a create event and it has a known room version - for _, ev := range authChain { - if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") { + for _, ev := range events { + if ev.Type() == spec.MRoomCreate && ev.StateKeyEquals("") { // make sure the room version is known content := ev.Content() verBody := struct { @@ -727,52 +664,33 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { } knownVersions := gomatrixserverlib.RoomVersions() if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok { - return fmt.Errorf("auth chain m.room.create event has an unknown room version: %s", verBody.Version) + return fmt.Errorf("m.room.create event has an unknown room version: %s", verBody.Version) } return nil } } - return fmt.Errorf("auth chain response is missing m.room.create event") + return fmt.Errorf("response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent( - joinEvent gomatrixserverlib.EventBuilder, -) gomatrixserverlib.RoomVersion { - // if auth events are not event references we know it must be v3+ - // we have to do these shenanigans to satisfy sytest, specifically for: - // "Outbound federation rejects m.room.create events with an unknown room version" - hasEventRefs := true - authEvents, ok := joinEvent.AuthEvents.([]interface{}) - if ok { - if len(authEvents) > 0 { - _, ok = authEvents[0].(string) - if ok { - // event refs are objects, not strings, so we know we must be dealing with a v3+ room. - hasEventRefs = false - } - } - } - - if hasEventRefs { - return gomatrixserverlib.RoomVersionV1 - } - return gomatrixserverlib.RoomVersionV4 -} - -// FederatedAuthProvider is an auth chain provider which fetches events from the server provided -func federatedAuthProvider( - ctx context.Context, federation api.FederationClient, - keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName, -) gomatrixserverlib.AuthChainProvider { +// federatedEventProvider is an event provider which fetches events from the server provided +func federatedEventProvider( + ctx context.Context, federation fclient.FederationClient, + keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, + userIDForSender spec.UserIDForSender, +) gomatrixserverlib.EventProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. - retries := map[string][]*gomatrixserverlib.Event{} + retries := map[string][]gomatrixserverlib.PDU{} // Define a function which we can pass to Check to retrieve missing // auth events inline. This greatly increases our chances of not having // to repeat the entire set of checks just for a missing event or two. - return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { - returning := []*gomatrixserverlib.Event{} + return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.PDU, error) { + returning := []gomatrixserverlib.PDU{} + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return nil, err + } // See if we have retry entries for each of the supplied event IDs. for _, eventID := range eventIDs { @@ -802,13 +720,13 @@ func federatedAuthProvider( // event ID again. for _, pdu := range tx.PDUs { // Try to parse the event. - ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + ev, everr := verImpl.NewEventFromUntrustedJSON(pdu) if everr != nil { return nil, fmt.Errorf("missingAuth gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr) } // Check the signatures of the event. - if err := ev.VerifyEventSignatures(ctx, keyRing); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil { return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) } @@ -868,7 +786,7 @@ func (r *FederationInternalAPI) P2PRemoveRelayServers( } func (r *FederationInternalAPI) shouldAttemptDirectFederation( - destination gomatrixserverlib.ServerName, + destination spec.ServerName, ) bool { var shouldRelay bool stats := r.statistics.ForServer(destination) diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index 90849dcf62..2f61235aec 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -24,26 +24,26 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) type testFedClient struct { - api.FederationClient + fclient.FederationClient queryKeysCalled bool claimKeysCalled bool shouldFail bool } -func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res fclient.RespDirectory, err error) { +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s spec.ServerName, roomAlias string) (res fclient.RespDirectory, err error) { return fclient.RespDirectory{}, nil } func TestPerformWakeupServers(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") + server := spec.ServerName("wakeup") testDB.AddServerToBlacklist(server) testDB.SetServerAssumedOffline(context.Background(), server) blacklisted, err := testDB.IsServerBlacklisted(server) @@ -73,7 +73,7 @@ func TestPerformWakeupServers(t *testing.T) { ) req := api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{server}, + ServerNames: []spec.ServerName{server}, } res := api.PerformWakeupServersResponse{} err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) @@ -90,8 +90,8 @@ func TestPerformWakeupServers(t *testing.T) { func TestQueryRelayServers(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") - relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + server := spec.ServerName("wakeup") + relayServers := []spec.ServerName{"relay1", "relay2"} err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) assert.NoError(t, err) @@ -127,8 +127,8 @@ func TestQueryRelayServers(t *testing.T) { func TestRemoveRelayServers(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") - relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + server := spec.ServerName("wakeup") + relayServers := []spec.ServerName{"relay1", "relay2"} err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) assert.NoError(t, err) @@ -153,7 +153,7 @@ func TestRemoveRelayServers(t *testing.T) { req := api.P2PRemoveRelayServersRequest{ Server: server, - RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + RelayServers: []spec.ServerName{"relay1"}, } res := api.P2PRemoveRelayServersResponse{} err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) @@ -162,7 +162,7 @@ func TestRemoveRelayServers(t *testing.T) { finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) assert.NoError(t, err) assert.Equal(t, 1, len(finalRelays)) - assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) + assert.Equal(t, spec.ServerName("relay2"), finalRelays[0]) } func TestPerformDirectoryLookup(t *testing.T) { @@ -199,9 +199,9 @@ func TestPerformDirectoryLookup(t *testing.T) { func TestPerformDirectoryLookupRelaying(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() - server := gomatrixserverlib.ServerName("wakeup") + server := spec.ServerName("wakeup") testDB.SetServerAssumedOffline(context.Background(), server) - testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + testDB.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{"relay"}) cfg := config.FederationAPI{ Matrix: &config.Global{ diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index 688afa8eae..e53f19ff8c 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -25,7 +26,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( return } -func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { +func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName spec.ServerName) (*gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBackingOffOrBlacklisted(serverName, func() (interface{}, error) { diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 6bcfafa39d..ede56436a5 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -46,7 +47,7 @@ type SyncAPIProducer struct { func (p *SyncAPIProducer) SendReceipt( ctx context.Context, - userID, roomID, eventID, receiptType string, timestamp gomatrixserverlib.Timestamp, + userID, roomID, eventID, receiptType string, timestamp spec.Timestamp, ) error { m := &nats.Msg{ Subject: p.TopicReceiptEvent, @@ -155,7 +156,7 @@ func (p *SyncAPIProducer) SendPresence( if statusMsg != nil { m.Header.Set("status_msg", *statusMsg) } - lastActiveTS := gomatrixserverlib.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond))) + lastActiveTS := spec.AsTimestamp(time.Now().Add(-(time.Duration(lastActiveAgo) * time.Millisecond))) m.Header.Set("last_active_ts", strconv.Itoa(int(lastActiveTS))) log.Tracef("Sending presence to syncAPI: %+v", m.Header) @@ -164,7 +165,7 @@ func (p *SyncAPIProducer) SendPresence( } func (p *SyncAPIProducer) SendDeviceListUpdate( - ctx context.Context, deviceListUpdate gomatrixserverlib.RawJSON, origin gomatrixserverlib.ServerName, + ctx context.Context, deviceListUpdate spec.RawJSON, origin spec.ServerName, ) (err error) { m := nats.NewMsg(p.TopicDeviceListUpdate) m.Header.Set("origin", string(origin)) @@ -175,7 +176,7 @@ func (p *SyncAPIProducer) SendDeviceListUpdate( } func (p *SyncAPIProducer) SendSigningKeyUpdate( - ctx context.Context, data gomatrixserverlib.RawJSON, origin gomatrixserverlib.ServerName, + ctx context.Context, data spec.RawJSON, origin spec.ServerName, ) (err error) { m := nats.NewMsg(p.TopicSigningKeyUpdate) m.Header.Set("origin", string(origin)) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a4542c4985..880aee0d35 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -24,14 +24,15 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "go.uber.org/atomic" - fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/process" ) @@ -51,11 +52,11 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing map[gomatrixserverlib.ServerName]*fclient.SigningIdentity + signing map[spec.ServerName]*fclient.SigningIdentity rsAPI api.FederationRoomserverAPI - client fedapi.FederationClient // federation client - origin gomatrixserverlib.ServerName // origin of requests - destination gomatrixserverlib.ServerName // destination of requests + client fclient.FederationClient // federation client + origin spec.ServerName // origin of requests + destination spec.ServerName // destination of requests running atomic.Bool // is the queue worker running? backingOff atomic.Bool // true if we're backing off overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more @@ -71,7 +72,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { +func (oq *destinationQueue) sendEvent(event *types.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -426,7 +427,7 @@ func (oq *destinationQueue) nextTransaction( relaySuccess := false logrus.Infof("Sending %q to relay servers: %v", t.TransactionID, relayServers) // TODO : how to pass through actual userID here?!?!?!?! - userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + userID, userErr := spec.NewUserID("@user:"+string(oq.destination), false) if userErr != nil { return userErr, sendMethod } @@ -507,7 +508,7 @@ func (oq *destinationQueue) createTransaction( // it so that we retry with the same transaction ID. oq.transactionIDMutex.Lock() if oq.transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) + now := spec.AsTimestamp(time.Now()) oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) } oq.transactionIDMutex.Unlock() @@ -518,7 +519,7 @@ func (oq *destinationQueue) createTransaction( } t.Origin = oq.origin t.Destination = oq.destination - t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.OriginServerTS = spec.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID var pduReceipts []*receipt.Receipt diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index c0ecb28759..26305ed7aa 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -23,16 +23,17 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/process" ) @@ -43,12 +44,12 @@ type OutgoingQueues struct { process *process.ProcessContext disabled bool rsAPI api.FederationRoomserverAPI - origin gomatrixserverlib.ServerName - client fedapi.FederationClient + origin spec.ServerName + client fclient.FederationClient statistics *statistics.Statistics - signing map[gomatrixserverlib.ServerName]*fclient.SigningIdentity + signing map[spec.ServerName]*fclient.SigningIdentity queuesMutex sync.Mutex // protects the below - queues map[gomatrixserverlib.ServerName]*destinationQueue + queues map[spec.ServerName]*destinationQueue } func init() { @@ -87,8 +88,8 @@ func NewOutgoingQueues( db storage.Database, process *process.ProcessContext, disabled bool, - origin gomatrixserverlib.ServerName, - client fedapi.FederationClient, + origin spec.ServerName, + client fclient.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, signing []*fclient.SigningIdentity, @@ -101,15 +102,15 @@ func NewOutgoingQueues( origin: origin, client: client, statistics: statistics, - signing: map[gomatrixserverlib.ServerName]*fclient.SigningIdentity{}, - queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, + signing: map[spec.ServerName]*fclient.SigningIdentity{}, + queues: map[spec.ServerName]*destinationQueue{}, } for _, identity := range signing { queues.signing[identity.ServerName] = identity } // Look up which servers we have pending items for and then rehydrate those queues. if !disabled { - serverNames := map[gomatrixserverlib.ServerName]struct{}{} + serverNames := map[spec.ServerName]struct{}{} if names, err := db.GetPendingPDUServerNames(process.Context()); err == nil { for _, serverName := range names { serverNames[serverName] = struct{}{} @@ -140,7 +141,7 @@ func NewOutgoingQueues( type queuedPDU struct { dbReceipt *receipt.Receipt - pdu *gomatrixserverlib.HeaderedEvent + pdu *types.HeaderedEvent } type queuedEDU struct { @@ -148,7 +149,7 @@ type queuedEDU struct { edu *gomatrixserverlib.EDU } -func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { +func (oqs *OutgoingQueues) getQueue(destination spec.ServerName) *destinationQueue { if oqs.statistics.ForServer(destination).Blacklisted() { return nil } @@ -187,8 +188,8 @@ func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { // SendEvent sends an event to the destinations func (oqs *OutgoingQueues) SendEvent( - ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, - destinations []gomatrixserverlib.ServerName, + ev *types.HeaderedEvent, origin spec.ServerName, + destinations []spec.ServerName, ) error { if oqs.disabled { log.Trace("Federation is disabled, not sending event") @@ -203,7 +204,7 @@ func (oqs *OutgoingQueues) SendEvent( // Deduplicate destinations and remove the origin from the list of // destinations just to be sure. - destmap := map[gomatrixserverlib.ServerName]struct{}{} + destmap := map[spec.ServerName]struct{}{} for _, d := range destinations { destmap[d] = struct{}{} } @@ -277,8 +278,8 @@ func (oqs *OutgoingQueues) SendEvent( // SendEDU sends an EDU event to the destinations. func (oqs *OutgoingQueues) SendEDU( - e *gomatrixserverlib.EDU, origin gomatrixserverlib.ServerName, - destinations []gomatrixserverlib.ServerName, + e *gomatrixserverlib.EDU, origin spec.ServerName, + destinations []spec.ServerName, ) error { if oqs.disabled { log.Trace("Federation is disabled, not sending EDU") @@ -293,7 +294,7 @@ func (oqs *OutgoingQueues) SendEDU( // Deduplicate destinations and remove the origin from the list of // destinations just to be sure. - destmap := map[gomatrixserverlib.ServerName]struct{}{} + destmap := map[spec.ServerName]struct{}{} for _, d := range destinations { destmap[d] = struct{}{} } @@ -376,7 +377,7 @@ func (oqs *OutgoingQueues) SendEDU( } // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { +func (oqs *OutgoingQueues) RetryServer(srv spec.ServerName, wasBlacklisted bool) { if oqs.disabled { return } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 65a925d348..cc38e136ff 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -25,16 +25,17 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "go.uber.org/atomic" "gotest.tools/v3/poll" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" - "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" @@ -74,7 +75,7 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont } type stubFederationClient struct { - api.FederationClient + fclient.FederationClient shouldTxSucceed bool shouldTxRelaySucceed bool txCount atomic.Uint32 @@ -91,7 +92,7 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse return fclient.RespSend{}, result } -func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res fclient.EmptyResp, err error) { +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u spec.UserID, t gomatrixserverlib.Transaction, forwardingServer spec.ServerName) (res fclient.EmptyResp, err error) { var result error if !f.shouldTxRelaySucceed { result = fmt.Errorf("relay transaction failed") @@ -101,19 +102,19 @@ func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u return fclient.EmptyResp{}, result } -func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { +func mustCreatePDU(t *testing.T) *types.HeaderedEvent { t.Helper() content := `{"type":"m.room.message"}` - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10).NewEventFromTrustedJSON([]byte(content), false) if err != nil { t.Fatalf("failed to create event: %v", err) } - return ev.Headered(gomatrixserverlib.RoomVersionV10) + return &types.HeaderedEvent{PDU: ev} } func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { t.Helper() - return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} + return &gomatrixserverlib.EDU{Type: spec.MTyping} } func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { @@ -143,7 +144,7 @@ func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32 func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -152,7 +153,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -172,7 +173,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -181,7 +182,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -201,7 +202,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -210,7 +211,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -231,7 +232,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -240,7 +241,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -261,7 +262,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -270,7 +271,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -289,7 +290,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { fc.shouldTxSucceed = true ev = mustCreatePDU(t) - err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err = queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) pollEnd := time.Now().Add(1 * time.Second) @@ -312,7 +313,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -321,7 +322,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -340,7 +341,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { fc.shouldTxSucceed = true ev = mustCreateEDU(t) - err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err = queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) pollEnd := time.Now().Add(1 * time.Second) @@ -363,7 +364,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -372,7 +373,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -395,7 +396,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -404,7 +405,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -427,7 +428,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -438,7 +439,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { queues.statistics.ForServer(destination).Failure() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -461,7 +462,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -472,7 +473,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { queues.statistics.ForServer(destination).Failure() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -495,7 +496,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -507,7 +508,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { // before it is blacklisted and deleted. dest := queues.getQueue(destination) ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) checkBlacklisted := func(log poll.LogT) poll.Result { @@ -546,7 +547,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -558,7 +559,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { // before it is blacklisted and deleted. dest := queues.getQueue(destination) ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) checkBlacklisted := func(log poll.LogT) poll.Result { @@ -597,7 +598,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { func TestSendPDUBatches(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) @@ -608,7 +609,7 @@ func TestSendPDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} // Populate database with > maxPDUsPerTransaction pduMultiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { @@ -620,7 +621,7 @@ func TestSendPDUBatches(t *testing.T) { } ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -641,7 +642,7 @@ func TestSendPDUBatches(t *testing.T) { func TestSendEDUBatches(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) @@ -652,7 +653,7 @@ func TestSendEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction eduMultiplier := uint32(3) for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { @@ -664,7 +665,7 @@ func TestSendEDUBatches(t *testing.T) { } ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -685,7 +686,7 @@ func TestSendEDUBatches(t *testing.T) { func TestSendPDUAndEDUBatches(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) @@ -696,7 +697,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction multiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { @@ -716,7 +717,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { } ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -739,7 +740,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -749,7 +750,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { dest := queues.getQueue(destination) queues.statistics.ForServer(destination).Failure() - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destinations := map[spec.ServerName]struct{}{destination: {}} ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) @@ -775,8 +776,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { // NOTE : Only one test case against real databases can be run at a time. t.Parallel() failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") - destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} + destination := spec.ServerName("remotehost") + destinations := map[spec.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. @@ -790,7 +791,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { // before it is blacklisted and deleted. dest := queues.getQueue(destination) ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) // NOTE : The server can be blacklisted before this, so manually inject the event @@ -843,7 +844,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(7) failuresUntilAssumedOffline := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -852,7 +853,7 @@ func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { }() ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -876,7 +877,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(7) failuresUntilAssumedOffline := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -885,7 +886,7 @@ func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { }() ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -909,7 +910,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) failuresUntilAssumedOffline := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -917,11 +918,11 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { <-pc.WaitForShutdown() }() - relayServers := []gomatrixserverlib.ServerName{"relayserver"} + relayServers := []spec.ServerName{"relayserver"} queues.statistics.ForServer(destination).AddRelayServers(relayServers) ev := mustCreatePDU(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEvent(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { @@ -948,7 +949,7 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) failuresUntilAssumedOffline := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") + destination := spec.ServerName("remotehost") db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) defer close() defer func() { @@ -956,11 +957,11 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { <-pc.WaitForShutdown() }() - relayServers := []gomatrixserverlib.ServerName{"relayserver"} + relayServers := []spec.ServerName{"relayserver"} queues.statistics.ForServer(destination).AddRelayServers(relayServers) ev := mustCreateEDU(t) - err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + err := queues.SendEDU(ev, "localhost", []spec.ServerName{destination}) assert.NoError(t, err) check := func(log poll.LogT) poll.Result { diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 272f5e9d8c..552c4eac20 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -21,10 +21,12 @@ import ( "strconv" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -32,7 +34,7 @@ import ( // https://matrix.org/docs/spec/server_server/unstable.html#get-matrix-federation-v1-backfill-roomid func Backfill( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, cfg *config.FederationAPI, @@ -47,7 +49,7 @@ func Backfill( if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()), + JSON: spec.MissingParam("Bad room ID: " + err.Error()), } } @@ -62,14 +64,14 @@ func Backfill( if !exists { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("v is missing"), + JSON: spec.MissingParam("v is missing"), } } limit = httpReq.URL.Query().Get("limit") if len(limit) == 0 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("limit is missing"), + JSON: spec.MissingParam("limit is missing"), } } @@ -89,23 +91,26 @@ func Backfill( util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("limit %q is invalid format", limit)), + JSON: spec.InvalidParam(fmt.Sprintf("limit %q is invalid format", limit)), } } - // Query the roomserver. + // Query the Roomserver. if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Filter any event that's not from the requested room out. - evs := make([]*gomatrixserverlib.Event, 0) + evs := make([]gomatrixserverlib.PDU, 0) - var ev *gomatrixserverlib.HeaderedEvent + var ev *types.HeaderedEvent for _, ev = range res.Events { if ev.RoomID() == roomID { - evs = append(evs, ev.Event) + evs = append(evs, ev.PDU) } } @@ -126,7 +131,7 @@ func Backfill( txn := gomatrixserverlib.Transaction{ Origin: request.Destination(), PDUs: eventJSONs, - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + OriginServerTS: spec.AsTimestamp(time.Now()), } // Send the events to the client. diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index aae1299fe1..a54ff0d9cd 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -16,10 +16,10 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -38,7 +38,10 @@ func GetUserDevices( } if res.Error != nil { util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } sigReq := &api.QuerySignaturesRequest{ @@ -50,9 +53,7 @@ func GetUserDevices( for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } - if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil { - return jsonerror.InternalAPIError(req.Context(), err) - } + keyAPI.QuerySignatures(req.Context(), sigReq, sigRes) response := fclient.RespUserDevices{ UserID: userID, @@ -91,10 +92,10 @@ func GetUserDevices( for sourceUserID, forSourceUser := range targetKey { for sourceKeyID, sourceKey := range forSourceUser { if device.Keys.Signatures == nil { - device.Keys.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + device.Keys.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } if _, ok := device.Keys.Signatures[sourceUserID]; !ok { - device.Keys.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + device.Keys.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } device.Keys.Signatures[sourceUserID][sourceKeyID] = sourceKey } diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 65a2a9bc8a..c26aa2f15b 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -16,17 +16,17 @@ import ( "context" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) // GetEventAuth returns event auth for the roomID and eventID func GetEventAuth( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, eventID string, @@ -43,9 +43,9 @@ func GetEventAuth( } if event.RoomID() != roomID { - return util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) if resErr != nil { return *resErr } @@ -72,7 +72,7 @@ func GetEventAuth( return util.JSONResponse{ Code: http.StatusOK, JSON: fclient.RespEventAuth{ - AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), + AuthEvents: types.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), }, } } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index b41292415d..d3f0e81c32 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -20,25 +20,21 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" ) // GetEvent returns the requested event func GetEvent( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, eventID string, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, ) util.JSONResponse { - err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) - if err != nil { - return *err - } // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, // which results in `QueryEventsByID` to first get the event and use that to determine the roomID. event, err := fetchEvent(ctx, rsAPI, "", eventID) @@ -46,9 +42,14 @@ func GetEvent( return *err } + err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) + if err != nil { + return *err + } + return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{ Origin: origin, - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + OriginServerTS: spec.AsTimestamp(time.Now()), PDUs: []json.RawMessage{ event.JSON(), }, @@ -59,11 +60,12 @@ func GetEvent( // otherwise it returns an error response which can be sent to the client. func allowedToSeeEvent( ctx context.Context, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, rsAPI api.FederationRoomserverAPI, eventID string, + roomID string, ) *util.JSONResponse { - allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) + allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID) if err != nil { resErr := util.ErrorResponse(err) return &resErr @@ -78,7 +80,7 @@ func allowedToSeeEvent( } // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. -func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { +func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (gomatrixserverlib.PDU, *util.JSONResponse) { var eventsResponse api.QueryEventsByIDResponse err := rsAPI.QueryEventsByID( ctx, @@ -93,9 +95,9 @@ func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, if len(eventsResponse.Events) == 0 { return nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Event not found"), + JSON: spec.NotFound("Event not found"), } } - return eventsResponse.Events[0].Event, nil + return eventsResponse.Events[0].PDU, nil } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index c1fdf266b3..e45209a2fd 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -20,48 +20,97 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" - roomserverVersion "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) // InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID} func InviteV2( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, - roomID string, + request *fclient.FederationRequest, + roomID spec.RoomID, eventID string, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, ) util.JSONResponse { - inviteReq := gomatrixserverlib.InviteV2Request{} + inviteReq := fclient.InviteV2Request{} err := json.Unmarshal(request.Content(), &inviteReq) switch e := err.(type) { case gomatrixserverlib.UnsupportedRoomVersionError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion( + JSON: spec.UnsupportedRoomVersion( fmt.Sprintf("Room version %q is not supported by this server.", e.Version), ), } case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case nil: - return processInvite( - httpReq.Context(), true, inviteReq.Event(), inviteReq.RoomVersion(), inviteReq.InviteRoomState(), roomID, eventID, cfg, rsAPI, keys, - ) + if inviteReq.Event().StateKey() == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("The invite event has no state key"), + } + } + + invitedUser, userErr := spec.NewUserID(*inviteReq.Event().StateKey(), true) + if userErr != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("The user ID is invalid"), + } + } + if !cfg.Matrix.IsLocalServerName(invitedUser.Domain()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("The invited user domain does not belong to this server"), + } + } + + if inviteReq.Event().EventID() != eventID { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), + } + } + + input := gomatrixserverlib.HandleInviteInput{ + RoomVersion: inviteReq.RoomVersion(), + RoomID: roomID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + RoomQuerier: rsAPI, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + StateQuerier: rsAPI.StateQuerier(), + InviteEvent: inviteReq.Event(), + StrippedState: inviteReq.InviteRoomState(), + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, + } + event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) + if jsonErr != nil { + return *jsonErr + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: fclient.RespInviteV2{Event: event.JSON()}, + } default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into an invite request. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into an invite request. " + err.Error()), } } } @@ -69,8 +118,8 @@ func InviteV2( // InviteV1 implements /_matrix/federation/v1/invite/{roomID}/{eventID} func InviteV1( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, - roomID string, + request *fclient.FederationRequest, + roomID spec.RoomID, eventID string, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, @@ -78,153 +127,122 @@ func InviteV1( ) util.JSONResponse { roomVer := gomatrixserverlib.RoomVersionV1 body := request.Content() - event, err := gomatrixserverlib.NewEventFromTrustedJSON(body, false, roomVer) + // roomVer is hardcoded to v1 so we know we won't panic on Must + event, err := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventFromTrustedJSON(body, false) switch err.(type) { case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case nil: default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into an invite v1 request. " + err.Error()), } } - var strippedState []gomatrixserverlib.InviteV2StrippedState - if err := json.Unmarshal(event.Unsigned(), &strippedState); err != nil { + var strippedState []gomatrixserverlib.InviteStrippedState + if jsonErr := json.Unmarshal(event.Unsigned(), &strippedState); jsonErr != nil { // just warn, they may not have added any. util.GetLogger(httpReq.Context()).Warnf("failed to extract stripped state from invite event") } - return processInvite( - httpReq.Context(), false, event, roomVer, strippedState, roomID, eventID, cfg, rsAPI, keys, - ) -} -func processInvite( - ctx context.Context, - isInviteV2 bool, - event *gomatrixserverlib.Event, - roomVer gomatrixserverlib.RoomVersion, - strippedState []gomatrixserverlib.InviteV2StrippedState, - roomID string, - eventID string, - cfg *config.FederationAPI, - rsAPI api.FederationRoomserverAPI, - keys gomatrixserverlib.JSONVerifier, -) util.JSONResponse { - - // Check that we can accept invites for this room version. - if _, err := roomserverVersion.SupportedRoomVersion(roomVer); err != nil { + if event.StateKey() == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion( - fmt.Sprintf("Room version %q is not supported by this server.", roomVer), - ), + JSON: spec.BadJSON("The invite event has no state key"), } } - // Check that the room ID is correct. - if event.RoomID() != roomID { + invitedUser, err := spec.NewUserID(*event.StateKey(), true) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), + JSON: spec.InvalidParam("The user ID is invalid"), } } - - // Check that the event ID is correct. - if event.EventID() != eventID { + if !cfg.Matrix.IsLocalServerName(invitedUser.Domain()) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), + JSON: spec.InvalidParam("The invited user domain does not belong to this server"), } } - if event.StateKey() == nil { + if event.EventID() != eventID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The invite event has no state key"), + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the invite event JSON"), } } - _, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), - } + input := gomatrixserverlib.HandleInviteInput{ + RoomVersion: roomVer, + RoomID: roomID, + InvitedUser: *invitedUser, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + RoomQuerier: rsAPI, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + StateQuerier: rsAPI.StateQuerier(), + InviteEvent: event, + StrippedState: strippedState, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, + } + event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) + if jsonErr != nil { + return *jsonErr + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: fclient.RespInvite{Event: event.JSON()}, } +} - // Check that the event is signed by the server sending the request. - redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON could not be redacted"), +func handleInvite(ctx context.Context, input gomatrixserverlib.HandleInviteInput, rsAPI api.FederationRoomserverAPI) (gomatrixserverlib.PDU, *util.JSONResponse) { + inviteEvent, err := gomatrixserverlib.HandleInvite(ctx, input) + switch e := err.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(ctx).WithError(err) + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(ctx).WithError(err) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorUnsupportedRoomVersion: + fallthrough // http.StatusBadRequest + case spec.ErrorBadJSON: + code = http.StatusBadRequest } - } - _, serverName, err := gomatrixserverlib.SplitID('@', event.Sender()) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON contains an invalid sender"), + + return nil, &util.JSONResponse{ + Code: code, + JSON: e, } - } - verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, - }} - verifyResults, err := keys.VerifyJSONs(ctx, verifyRequests) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed") - return jsonerror.InternalServerError() - } - if verifyResults[0].Error != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The invite must be signed by the server it originated on"), + default: + util.GetLogger(ctx).WithError(err) + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), } } - // Sign the event so that other servers will know that we have received the invite. - signedEvent := event.Sign( - string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, - ) - - // Add the invite event to the roomserver. - inviteEvent := signedEvent.Headered(roomVer) - request := &api.PerformInviteRequest{ - Event: inviteEvent, - InviteRoomState: strippedState, - RoomVersion: inviteEvent.RoomVersion, - SendAsServer: string(api.DoNotSendToOtherServers), - TransactionID: nil, - } - response := &api.PerformInviteResponse{} - if err := rsAPI.PerformInvite(ctx, request, response); err != nil { - util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") - return util.JSONResponse{ + headeredInvite := &types.HeaderedEvent{PDU: inviteEvent} + if err = rsAPI.HandleInvite(ctx, headeredInvite); err != nil { + util.GetLogger(ctx).WithError(err).Error("HandleInvite failed") + return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), - } - } - if response.Error != nil { - return response.Error.JSONResponse() - } - // Return the signed event to the originating server, it should then tell - // the other servers in the room that we have been invited. - if isInviteV2 { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: fclient.RespInviteV2{Event: signedEvent.JSON()}, - } - } else { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: fclient.RespInvite{Event: signedEvent.JSON()}, + JSON: spec.InternalServerError{}, } } + return inviteEvent, nil } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 1476f903f8..bfa1ba8b8d 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -15,7 +15,7 @@ package routing import ( - "encoding/json" + "context" "fmt" "net/http" "sort" @@ -23,163 +23,165 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" ) // MakeJoin implements the /make_join API func MakeJoin( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - roomID, userID string, + roomID spec.RoomID, userID spec.UserID, remoteVersions []gomatrixserverlib.RoomVersion, ) util.JSONResponse { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("failed obtaining room version") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } - // Check that the room that the remote side is trying to join is actually - // one of the room versions that they listed in their supported ?ver= in - // the make_join URL. - // https://matrix.org/docs/spec/server_server/r0.1.3#get-matrix-federation-v1-make-join-roomid-userid - remoteSupportsVersion := false - for _, v := range remoteVersions { - if v == verRes.RoomVersion { - remoteSupportsVersion = true - break - } + req := api.QueryServerJoinedToRoomRequest{ + ServerName: request.Destination(), + RoomID: roomID.String(), } - // If it isn't, stop trying to join the room. - if !remoteSupportsVersion { + res := api.QueryServerJoinedToRoomResponse{} + if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.IncompatibleRoomVersion(verRes.RoomVersion), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid UserID"), + createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { + identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination()) + if signErr != nil { + util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination()) + return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) } - } - if domain != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The join must be sent by the server of the user"), - } - } - // Check if we think we are still joined to the room - inRoomReq := &api.QueryServerJoinedToRoomRequest{ - ServerName: cfg.Matrix.ServerName, - RoomID: roomID, - } - inRoomRes := &api.QueryServerJoinedToRoomResponse{} - if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), inRoomReq, inRoomRes); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") - return jsonerror.InternalServerError() - } - if !inRoomRes.RoomExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q was not found on this server", roomID)), + queryRes := api.QueryLatestEventsAndStateResponse{ + RoomVersion: roomVersion, } - } - if !inRoomRes.IsInRoom { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Room ID %q has no remaining users on this server", roomID)), + event, signErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) + switch e := signErr.(type) { + case nil: + case eventutil.ErrRoomNoExists: + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") + return nil, nil, spec.NotFound("Room does not exist") + case gomatrixserverlib.BadJSONError: + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") + return nil, nil, spec.BadJSON(e.Error()) + default: + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") + return nil, nil, spec.InternalServerError{} } - } - // Check if the restricted join is allowed. If the room doesn't - // support restricted joins then this is effectively a no-op. - res, authorisedVia, err := checkRestrictedJoin(httpReq, rsAPI, verRes.RoomVersion, roomID, userID) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("checkRestrictedJoin failed") - return jsonerror.InternalServerError() - } else if res != nil { - return *res + stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) + for i, stateEvent := range queryRes.StateEvents { + stateEvents[i] = stateEvent.PDU + } + return event, stateEvents, nil } - // Try building an event for the server - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, - RoomID: roomID, - Type: "m.room.member", - StateKey: &userID, - } - content := gomatrixserverlib.MemberContent{ - Membership: gomatrixserverlib.Join, - AuthorisedVia: authorisedVia, - } - if err = builder.SetContent(content); err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed") - return jsonerror.InternalServerError() + roomQuerier := api.JoinRoomQuerier{ + Roomserver: rsAPI, } - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound( - fmt.Sprintf("Server name %q does not exist", request.Destination()), - ), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - queryRes := api.QueryLatestEventsAndStateResponse{ - RoomVersion: verRes.RoomVersion, + if senderID == "" { + senderID = spec.SenderID(userID.String()) + } + + input := gomatrixserverlib.HandleMakeJoinInput{ + Context: httpReq.Context(), + UserID: userID, + SenderID: senderID, + RoomID: roomID, + RoomVersion: roomVersion, + RemoteVersions: remoteVersions, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + RoomQuerier: &roomQuerier, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, + BuildEventTemplate: createJoinTemplate, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + response, internalErr := gomatrixserverlib.HandleMakeJoin(input) + switch e := internalErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnableToAuthoriseJoin: + fallthrough // http.StatusBadRequest + case spec.ErrorBadJSON: + code = http.StatusBadRequest + } + + return util.JSONResponse{ + Code: code, + JSON: e, } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { + case spec.IncompatibleRoomVersionError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + JSON: e, + } + default: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_join request") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), } - } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return jsonerror.InternalServerError() - } - - // Check that the join is allowed or not - stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) - for i := range queryRes.StateEvents { - stateEvents[i] = queryRes.StateEvents[i].Event } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response") return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(err.Error()), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } return util.JSONResponse{ Code: http.StatusOK, JSON: map[string]interface{}{ - "event": builder, - "room_version": verRes.RoomVersion, + "event": response.JoinTemplateEvent, + "room_version": response.RoomVersion, }, } } @@ -190,238 +192,146 @@ func MakeJoin( // nolint:gocyclo func SendJoin( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, - roomID, eventID string, + roomID spec.RoomID, + eventID string, ) util.JSONResponse { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) + if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), - } - } - - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()), - } - } - - // Check that a state key is provided. - if event.StateKey() == nil || event.StateKeyEquals("") { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("No state key was provided in the join event."), - } - } - if !event.StateKeyEquals(event.Sender()) { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Event state key must match the event sender."), - } + JSON: spec.InternalServerError{}, + } + } + + input := gomatrixserverlib.HandleSendJoinInput{ + Context: httpReq.Context(), + RoomID: roomID, + EventID: eventID, + JoinEvent: request.Content(), + RoomVersion: roomVersion, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, + StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error { + userID, userErr := spec.NewUserID(userIDRaw, true) + if userErr != nil { + return userErr + } + return rsAPI.StoreUserRoomPublicKey(ctx, senderID, *userID, roomID) + }, } - - // Check that the sender belongs to the server that is sending us - // the request. By this point we've already asserted that the sender - // and the state key are equal so we don't need to check both. - var serverName gomatrixserverlib.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender of the join is invalid"), - } - } else if serverName != request.Origin() { + response, joinErr := gomatrixserverlib.HandleSendJoin(input) + switch e := joinErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(joinErr) return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(joinErr) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnsupportedRoomVersion: + code = http.StatusInternalServerError + case spec.ErrorBadJSON: + code = http.StatusBadRequest } - } - // Check that the room ID is correct. - if event.RoomID() != roomID { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON( - fmt.Sprintf( - "The room ID in the request path (%q) must match the room ID in the join event JSON (%q)", - roomID, event.RoomID(), - ), - ), + Code: code, + JSON: e, } - } - - // Check that the event ID is correct. - if event.EventID() != eventID { + default: + util.GetLogger(httpReq.Context()).WithError(joinErr) return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON( - fmt.Sprintf( - "The event ID in the request path (%q) must match the event ID in the join event JSON (%q)", - eventID, event.EventID(), - ), - ), + JSON: spec.Unknown("unknown error"), } } - // Check that this is in fact a join event - membership, err := event.Membership() - if err != nil { + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response") return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing content.membership key"), - } - } - if membership != gomatrixserverlib.Join { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("membership must be 'join'"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } - } - // Check that the event is signed by the server sending the request. - redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) - if err != nil { - logrus.WithError(err).Errorf("XXX: join.go") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON could not be redacted"), - } - } - verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, - }} - verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") - return jsonerror.InternalServerError() - } - if verifyResults[0].Error != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()), - } } // Fetch the state and auth chain. We do this before we send the events // on, in case this fails. var stateAndAuthChainResponse api.QueryStateAndAuthChainResponse err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ - PrevEventIDs: event.PrevEventIDs(), - AuthEventIDs: event.AuthEventIDs(), - RoomID: roomID, + PrevEventIDs: response.JoinEvent.PrevEventIDs(), + AuthEventIDs: response.JoinEvent.AuthEventIDs(), + RoomID: roomID.String(), ResolveState: true, }, &stateAndAuthChainResponse) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryStateAndAuthChain failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !stateAndAuthChainResponse.RoomExists { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + JSON: spec.NotFound("Room does not exist"), } } if !stateAndAuthChainResponse.StateKnown { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("State not known"), - } - } - - // Check if the user is already in the room. If they're already in then - // there isn't much point in sending another join event into the room. - // Also check to see if they are banned: if they are then we reject them. - alreadyJoined := false - isBanned := false - for _, se := range stateAndAuthChainResponse.StateEvents { - if !se.StateKeyEquals(*event.StateKey()) { - continue - } - if membership, merr := se.Membership(); merr == nil { - alreadyJoined = (membership == gomatrixserverlib.Join) - isBanned = (membership == gomatrixserverlib.Ban) - break - } - } - - if isBanned { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user is banned"), + JSON: spec.Forbidden("State not known"), } } - // If the membership content contains a user ID for a server that is not - // ours then we should kick it back. - var memberContent gomatrixserverlib.MemberContent - if err := json.Unmarshal(event.Content(), &memberContent); err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), - } - } - if memberContent.AuthorisedVia != "" { - _, domain, err := gomatrixserverlib.SplitID('@', memberContent.AuthorisedVia) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)), - } - } - if domain != cfg.Matrix.ServerName { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)), - } - } - } - - // Sign the membership event. This is required for restricted joins to work - // in the case that the authorised via user is one of our own users. It also - // doesn't hurt to do it even if it isn't a restricted join. - signed := event.Sign( - string(cfg.Matrix.ServerName), - cfg.Matrix.KeyID, - cfg.Matrix.PrivateKey, - ) - // Send the events to the room server. // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName - if !alreadyJoined { - var response api.InputRoomEventsResponse - if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + if !response.AlreadyJoined { + var rsResponse api.InputRoomEventsResponse + rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: signed.Headered(stateAndAuthChainResponse.RoomVersion), + Event: &types.HeaderedEvent{PDU: response.JoinEvent}, SendAsServer: string(cfg.Matrix.ServerName), TransactionID: nil, }, }, - }, &response); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } - if response.ErrMsg != "" { - util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") - if response.NotAllowed { + }, &rsResponse) + if rsResponse.ErrMsg != "" { + util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, rsResponse.ErrMsg).Error("SendEvents failed") + if rsResponse.NotAllowed { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden(response.ErrMsg), + JSON: spec.Forbidden(rsResponse.ErrMsg), } } - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -435,81 +345,15 @@ func SendJoin( return util.JSONResponse{ Code: http.StatusOK, JSON: fclient.RespSendJoin{ - StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), - AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), + StateEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), + AuthEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), Origin: cfg.Matrix.ServerName, - Event: signed.JSON(), + Event: response.JoinEvent.JSON(), }, } } -// checkRestrictedJoin finds out whether or not we can assist in processing -// a restricted room join. If the room version does not support restricted -// joins then this function returns with no side effects. This returns three -// values: -// - an optional JSON response body (i.e. M_UNABLE_TO_AUTHORISE_JOIN) which -// should always be sent back to the client if one is specified -// - a user ID of an authorising user, typically a user that has power to -// issue invites in the room, if one has been found -// - an error if there was a problem finding out if this was allowable, -// like if the room version isn't known or a problem happened talking to -// the roomserver -func checkRestrictedJoin( - httpReq *http.Request, - rsAPI api.FederationRoomserverAPI, - roomVersion gomatrixserverlib.RoomVersion, - roomID, userID string, -) (*util.JSONResponse, string, error) { - if allowRestricted, err := roomVersion.MayAllowRestrictedJoinsInEventAuth(); err != nil { - return nil, "", err - } else if !allowRestricted { - return nil, "", nil - } - req := &api.QueryRestrictedJoinAllowedRequest{ - RoomID: roomID, - UserID: userID, - } - res := &api.QueryRestrictedJoinAllowedResponse{} - if err := rsAPI.QueryRestrictedJoinAllowed(httpReq.Context(), req, res); err != nil { - return nil, "", err - } - - switch { - case !res.Restricted: - // The join rules for the room don't restrict membership. - return nil, "", nil - - case !res.Resident: - // The join rules restrict membership but our server isn't currently - // joined to all of the allowed rooms, so we can't actually decide - // whether or not to allow the user to join. This error code should - // tell the joining server to try joining via another resident server - // instead. - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.UnableToAuthoriseJoin("This server cannot authorise the join."), - }, "", nil - - case !res.Allowed: - // The join rules restrict membership, our server is in the relevant - // rooms and the user wasn't joined to join any of the allowed rooms - // and therefore can't join this room. - return &util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You are not joined to any matching rooms."), - }, "", nil - - default: - // The join rules restrict membership, our server is in the relevant - // rooms and the user was allowed to join because they belong to one - // of the allowed rooms. We now need to pick one of our own local users - // from within the room to use as the authorising user ID, so that it - // can be referred to from within the membership content. - return nil, res.AuthorisedVia, nil - } -} - -type eventsByDepth []*gomatrixserverlib.HeaderedEvent +type eventsByDepth []*types.HeaderedEvent func (e eventsByDepth) Len() int { return len(e) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index db768591f7..3d8ff2deab 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -20,12 +20,12 @@ import ( "time" clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" @@ -38,14 +38,14 @@ type queryKeysRequest struct { // QueryDeviceKeys returns device keys for users on this server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query func QueryDeviceKeys( - httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.FederationKeyAPI, thisServer gomatrixserverlib.ServerName, + httpReq *http.Request, request *fclient.FederationRequest, keyAPI api.FederationKeyAPI, thisServer spec.ServerName, ) util.JSONResponse { var qkr queryKeysRequest err := json.Unmarshal(request.Content(), &qkr) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // make sure we only query users on our domain @@ -62,14 +62,15 @@ func QueryDeviceKeys( } var queryRes api.QueryKeysResponse - if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ + keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ UserToDevices: qkr.DeviceKeys, - }, &queryRes); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &queryRes) if queryRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: 200, @@ -92,14 +93,14 @@ type claimOTKsRequest struct { // ClaimOneTimeKeys claims OTKs for users on this server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim func ClaimOneTimeKeys( - httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.FederationKeyAPI, thisServer gomatrixserverlib.ServerName, + httpReq *http.Request, request *fclient.FederationRequest, keyAPI api.FederationKeyAPI, thisServer spec.ServerName, ) util.JSONResponse { var cor claimOTKsRequest err := json.Unmarshal(request.Content(), &cor) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // make sure we only claim users on our domain @@ -116,14 +117,15 @@ func ClaimOneTimeKeys( } var claimRes api.PerformClaimKeysResponse - if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ + keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: cor.OneTimeKeys, - }, &claimRes); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &claimRes) if claimRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: 200, @@ -135,7 +137,7 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys -func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { +func LocalKeys(cfg *config.FederationAPI, serverName spec.ServerName) util.JSONResponse { keys, err := localKeys(cfg, serverName) if err != nil { return util.MessageResponse(http.StatusNotFound, err.Error()) @@ -143,7 +145,7 @@ func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName spec.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys var identity *fclient.SigningIdentity var err error @@ -153,10 +155,10 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam } publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = cfg.Matrix.ServerName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys.ValidUntilTS = spec.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ cfg.Matrix.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), + Key: spec.Base64Bytes(publicKey), }, } keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} @@ -174,10 +176,10 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam } publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = virtualHost.ServerName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) + keys.ValidUntilTS = spec.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ virtualHost.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), + Key: spec.Base64Bytes(publicKey), }, } // TODO: Virtual hosts probably want to be able to specify old signing @@ -200,11 +202,11 @@ func NotaryKeys( fsAPI federationAPI.FederationInternalAPI, req *gomatrixserverlib.PublicKeyNotaryLookupRequest, ) util.JSONResponse { - serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal + serverName := spec.ServerName(httpReq.Host) // TODO: this is not ideal if !cfg.Matrix.IsLocalServerName(serverName) { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Server name not known"), + JSON: spec.NotFound("Server name not known"), } } @@ -247,7 +249,10 @@ func NotaryKeys( j, err := json.Marshal(keys) if err != nil { logrus.WithError(err).Errorf("Failed to marshal %q response", serverName) - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } js, err := gomatrixserverlib.SignJSON( @@ -255,7 +260,10 @@ func NotaryKeys( ) if err != nil { logrus.WithError(err).Errorf("Failed to sign %q response", serverName) - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } response.ServerKeys = append(response.ServerKeys, js) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index f1e9f49ba0..5c8dd00f3e 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -17,11 +17,13 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -29,101 +31,131 @@ import ( // MakeLeave implements the /make_leave API func MakeLeave( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - roomID, userID string, + roomID spec.RoomID, userID spec.UserID, ) util.JSONResponse { - _, domain, err := gomatrixserverlib.SplitID('@', userID) + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("failed obtaining room version") return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid UserID"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - if domain != request.Origin() { + + req := api.QueryServerJoinedToRoomRequest{ + ServerName: request.Destination(), + RoomID: roomID.String(), + } + res := api.QueryServerJoinedToRoomResponse{} + if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The leave must be sent by the server of the user"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - // Try building an event for the server - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, - RoomID: roomID, - Type: "m.room.member", - StateKey: &userID, - } - err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave}) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("builder.SetContent failed") - return jsonerror.InternalServerError() + createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { + identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination()) + if signErr != nil { + util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination()) + return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) + } + + queryRes := api.QueryLatestEventsAndStateResponse{} + event, buildErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) + switch e := buildErr.(type) { + case nil: + case eventutil.ErrRoomNoExists: + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") + return nil, nil, spec.NotFound("Room does not exist") + case gomatrixserverlib.BadJSONError: + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") + return nil, nil, spec.BadJSON(e.Error()) + default: + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") + return nil, nil, spec.InternalServerError{} + } + + stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) + for i, stateEvent := range queryRes.StateEvents { + stateEvents[i] = stateEvent.PDU + } + return event, stateEvents, nil } - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound( - fmt.Sprintf("Server name %q does not exist", request.Destination()), - ), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } - var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { + input := gomatrixserverlib.HandleMakeLeaveInput{ + UserID: userID, + SenderID: senderID, + RoomID: roomID, + RoomVersion: roomVersion, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + BuildEventTemplate: createLeaveTemplate, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, + } + + response, internalErr := gomatrixserverlib.HandleMakeLeave(input) + switch e := internalErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_leave request") return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room does not exist"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_leave request") + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorBadJSON: + code = http.StatusBadRequest } - } else if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") - return jsonerror.InternalServerError() - } - // If the user has already left then just return their last leave - // event. This means that /send_leave will be a no-op, which helps - // to reject invites multiple times - hopefully. - for _, state := range queryRes.StateEvents { - if !state.StateKeyEquals(userID) { - continue - } - if mem, merr := state.Membership(); merr == nil && mem == gomatrixserverlib.Leave { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: map[string]interface{}{ - "room_version": event.RoomVersion, - "event": state, - }, - } + return util.JSONResponse{ + Code: code, + JSON: e, + } + default: + util.GetLogger(httpReq.Context()).WithError(internalErr).Error("failed to handle make_leave request") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), } } - // Check that the leave is allowed or not - stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) - for i := range queryRes.StateEvents { - stateEvents[i] = queryRes.StateEvents[i].Event - } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeLeave returned invalid response") return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(err.Error()), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } } return util.JSONResponse{ Code: http.StatusOK, JSON: map[string]interface{}{ - "room_version": event.RoomVersion, - "event": builder, + "event": response.LeaveTemplateEvent, + "room_version": response.RoomVersion, }, } } @@ -132,34 +164,43 @@ func MakeLeave( // nolint:gocyclo func SendLeave( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, roomID, eventID string, ) util.JSONResponse { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), + } + } + + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.UnsupportedRoomVersion( + fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion), + ), } } // Decode the event JSON from the request. - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(request.Content(), verRes.RoomVersion) + event, err := verImpl.NewEventFromUntrustedJSON(request.Content()) switch err.(type) { case gomatrixserverlib.BadJSONError: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case nil: default: return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } @@ -167,7 +208,7 @@ func SendLeave( if event.RoomID() != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"), + JSON: spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON"), } } @@ -175,36 +216,43 @@ func SendLeave( if event.EventID() != eventID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event ID in the request path must match the event ID in the leave event JSON"), + JSON: spec.BadJSON("The event ID in the request path must match the event ID in the leave event JSON"), } } if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("No state key was provided in the leave event."), + JSON: spec.BadJSON("No state key was provided in the leave event."), } } - if !event.StateKeyEquals(event.Sender()) { + if !event.StateKeyEquals(string(event.SenderID())) { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Event state key must match the event sender."), + JSON: spec.BadJSON("Event state key must match the event sender."), } } // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName gomatrixserverlib.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Room ID is invalid."), + } + } + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID()) + if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender of the join is invalid"), + JSON: spec.Forbidden("The sender of the join is invalid"), } - } else if serverName != request.Origin() { + } else if sender.Domain() != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"), + JSON: spec.Forbidden("The sender does not match the server that originated the request"), } } @@ -213,7 +261,7 @@ func SendLeave( RoomID: roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: *event.StateKey(), }, }, @@ -222,7 +270,10 @@ func SendLeave( err = rsAPI.QueryLatestEventsAndState(httpReq.Context(), queryReq, queryRes) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryLatestEventsAndState failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // The room doesn't exist or we weren't ever joined to it. Might as well // no-op here. @@ -242,7 +293,7 @@ func SendLeave( // We are/were joined/invited/banned or something. Check if // we can no-op here. if len(queryRes.StateEvents) == 1 { - if mem, merr := queryRes.StateEvents[0].Membership(); merr == nil && mem == gomatrixserverlib.Leave { + if mem, merr := queryRes.StateEvents[0].Membership(); merr == nil && mem == spec.Leave { return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, @@ -251,29 +302,32 @@ func SendLeave( } // Check that the event is signed by the server sending the request. - redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) + redacted, err := verImpl.RedactEventJSON(event.JSON()) if err != nil { logrus.WithError(err).Errorf("XXX: leave.go") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event JSON could not be redacted"), + JSON: spec.BadJSON("The event JSON could not be redacted"), } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, + ServerName: sender.Domain(), + Message: redacted, + AtTS: event.OriginServerTS(), + ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck, }} verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if verifyResults[0].Error != nil { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The leave must be signed by the server it originated on"), + JSON: spec.Forbidden("The leave must be signed by the server it originated on"), } } @@ -283,13 +337,13 @@ func SendLeave( util.GetLogger(httpReq.Context()).WithError(err).Error("event.Membership failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("missing content.membership key"), + JSON: spec.BadJSON("missing content.membership key"), } } - if mem != gomatrixserverlib.Leave { + if mem != spec.Leave { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"), + JSON: spec.BadJSON("The membership in the event content must be set to leave"), } } @@ -297,28 +351,29 @@ func SendLeave( // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName var response api.InputRoomEventsResponse - if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: event.Headered(verRes.RoomVersion), + Event: &types.HeaderedEvent{PDU: event}, SendAsServer: string(cfg.Matrix.ServerName), TransactionID: nil, }, }, - }, &response); err != nil { - return jsonerror.InternalAPIError(httpReq.Context(), err) - } + }, &response) if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed") if response.NotAllowed { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden(response.ErrMsg), + JSON: spec.Forbidden(response.ErrMsg), } } - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index 63a32b9c45..f57d302041 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -16,10 +16,10 @@ import ( "encoding/json" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -34,7 +34,7 @@ type getMissingEventRequest struct { // Events are fetched from room DAG starting from latest_events until we reach earliest_events or the limit. func GetMissingEvents( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { @@ -42,7 +42,7 @@ func GetMissingEvents( if err := json.Unmarshal(request.Content(), &gme); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } @@ -63,13 +63,16 @@ func GetMissingEvents( &eventsResponse, ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.QueryMissingEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) resp := fclient.RespMissingEvents{ - Events: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(eventsResponse.Events), + Events: types.NewEventJSONsFromHeaderedEvents(eventsResponse.Events), } return util.JSONResponse{ @@ -80,8 +83,8 @@ func GetMissingEvents( // filterEvents returns only those events with matching roomID func filterEvents( - events []*gomatrixserverlib.HeaderedEvent, roomID string, -) []*gomatrixserverlib.HeaderedEvent { + events []*types.HeaderedEvent, roomID string, +) []*types.HeaderedEvent { ref := events[:0] for _, ev := range events { if ev.RoomID() == roomID { diff --git a/federationapi/routing/openid.go b/federationapi/routing/openid.go index cbc75a9a72..d28f319f57 100644 --- a/federationapi/routing/openid.go +++ b/federationapi/routing/openid.go @@ -18,8 +18,8 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -36,7 +36,7 @@ func GetOpenIDUserInfo( if len(token) == 0 { return util.JSONResponse{ Code: http.StatusUnauthorized, - JSON: jsonerror.MissingArgument("access_token is missing"), + JSON: spec.MissingParam("access_token is missing"), } } @@ -55,7 +55,7 @@ func GetOpenIDUserInfo( nowMS := time.Now().UnixNano() / int64(time.Millisecond) if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS { code = http.StatusUnauthorized - res = jsonerror.UnknownToken("Access Token unknown or expired") + res = spec.UnknownToken("Access Token unknown or expired") } return util.JSONResponse{ diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index 6c4d315c0c..f5003b147d 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -17,31 +17,30 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) // Peek implements the SS /peek API, handling inbound peeks func Peek( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, roomID, peekID string, remoteVersions []gomatrixserverlib.RoomVersion, ) util.JSONResponse { // TODO: check if we're just refreshing an existing peek by querying the federationapi - - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -50,7 +49,7 @@ func Peek( // the peek URL. remoteSupportsVersion := false for _, v := range remoteVersions { - if v == verRes.RoomVersion { + if v == roomVersion { remoteSupportsVersion = true break } @@ -59,7 +58,7 @@ func Peek( if !remoteSupportsVersion { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.IncompatibleRoomVersion(verRes.RoomVersion), + JSON: spec.IncompatibleRoomVersion(string(roomVersion)), } } @@ -69,7 +68,7 @@ func Peek( renewalInterval := int64(60 * 60 * 1000 * 1000) var response api.PerformInboundPeekResponse - err := rsAPI.PerformInboundPeek( + err = rsAPI.PerformInboundPeek( httpReq.Context(), &api.PerformInboundPeekRequest{ RoomID: roomID, @@ -89,10 +88,10 @@ func Peek( } respPeek := fclient.RespPeek{ - StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.StateEvents), - AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), + StateEvents: types.NewEventJSONsFromHeaderedEvents(response.StateEvents), + AuthEvents: types.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), RoomVersion: response.RoomVersion, - LatestEvent: response.LatestEvent.Unwrap(), + LatestEvent: response.LatestEvent.PDU, RenewalInterval: renewalInterval, } diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index 55641b216e..e6a488ba33 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -18,10 +18,10 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -37,7 +37,7 @@ func GetProfile( if userID == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("The request body did not contain required argument 'user_id'."), + JSON: spec.MissingParam("The request body did not contain required argument 'user_id'."), } } @@ -46,14 +46,17 @@ func GetProfile( util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), + JSON: spec.InvalidParam(fmt.Sprintf("Domain %q does not match this server", domain)), } } profile, err := userAPI.QueryProfile(httpReq.Context(), userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryProfile failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var res interface{} @@ -71,7 +74,7 @@ func GetProfile( } default: code = http.StatusBadRequest - res = jsonerror.InvalidArgumentValue("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") + res = spec.InvalidParam("The request body did not contain an allowed value of argument 'field'. Allowed values are either: 'avatar_url', 'displayname'.") } } else { res = eventutil.UserProfile{ diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go index d249fce14f..a31b206c1c 100644 --- a/federationapi/routing/profile_test.go +++ b/federationapi/routing/profile_test.go @@ -36,6 +36,8 @@ import ( "github.com/matrix-org/dendrite/test/testrig" userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -69,14 +71,14 @@ func TestHandleQueryProfile(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(routers, cfg, nil, r, keyRing, &fedClient, &userapi, &cfg.MSCs, nil, nil, caching.DisableMetrics) + routing.Setup(routers, cfg, nil, r, keyRing, &fedClient, &userapi, &cfg.MSCs, nil, caching.DisableMetrics) handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + serverName := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) type queryContent struct{} content := queryContent{} err := req.SetContent(content) diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 7c5d6a02e1..213d1631a5 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -8,10 +8,10 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" ) @@ -39,7 +39,10 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.FederationRoomser } response, err := publicRooms(req.Context(), request, rsAPI) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ Code: http.StatusOK, @@ -106,8 +109,10 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO // In that case, we want to assign 0 so we ignore the error if err != nil && len(httpReq.FormValue("limit")) > 0 { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") - reqErr := jsonerror.InternalServerError() - return &reqErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } request.Limit = int16(limit) request.Since = httpReq.FormValue("since") @@ -118,7 +123,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO return &util.JSONResponse{ Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + JSON: spec.NotFound("Bad method"), } } @@ -126,11 +131,11 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.FederationRoomserverAPI) ([]fclient.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} - canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCanonicalAlias, StateKey: ""} topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} - visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} - joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""} var stateRes roomserverAPI.QueryBulkStateContentResponse err := rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ @@ -138,7 +143,7 @@ func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.Fede AllowWildcards: true, StateTuples: []gomatrixserverlib.StateKeyTuple{ nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, - {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + {EventType: spec.MRoomMember, StateKey: "*"}, }, }, &stateRes) if err != nil { @@ -154,7 +159,7 @@ func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.Fede joinCount := 0 var joinRule, guestAccess string for tuple, contentVal := range data { - if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + if tuple.EventType == spec.MRoomMember && contentVal == "join" { joinCount++ continue } @@ -178,7 +183,7 @@ func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.Fede guestAccess = contentVal } } - if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + if joinRule == spec.Public && guestAccess == "can_join" { pub.GuestCanJoin = true } pub.JoinedMembersCount = joinCount diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 6b1c371ecd..2e845f32ca 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -18,20 +18,20 @@ import ( "fmt" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) // RoomAliasToID converts the queried alias into a room ID and returns it func RoomAliasToID( httpReq *http.Request, - federation federationAPI.FederationClient, + federation fclient.FederationClient, cfg *config.FederationAPI, rsAPI roomserverAPI.FederationRoomserverAPI, senderAPI federationAPI.FederationInternalAPI, @@ -40,14 +40,14 @@ func RoomAliasToID( if roomAlias == "" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Must supply room alias parameter."), + JSON: spec.BadJSON("Must supply room alias parameter."), } } _, domain, err := gomatrixserverlib.SplitID('#', roomAlias) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"), + JSON: spec.BadJSON("Room alias must be in the form '#localpart:domain'"), } } @@ -61,7 +61,10 @@ func RoomAliasToID( queryRes := &roomserverAPI.GetRoomIDForAliasResponse{} if err = rsAPI.GetRoomIDForAlias(httpReq.Context(), queryReq, queryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if queryRes.RoomID != "" { @@ -69,7 +72,10 @@ func RoomAliasToID( var serverQueryRes federationAPI.QueryJoinedHostServerNamesInRoomResponse if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("senderAPI.QueryJoinedHostServerNamesInRoom failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } resp = fclient.RespDirectory{ @@ -80,7 +86,7 @@ func RoomAliasToID( // If no alias was found, return an error return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Room alias %s not found", roomAlias)), + JSON: spec.NotFound(fmt.Sprintf("Room alias %s not found", roomAlias)), } } } else { @@ -91,14 +97,17 @@ func RoomAliasToID( if x.Code == http.StatusNotFound { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room alias not found"), + JSON: spec.NotFound("Room alias not found"), } } } // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. util.GetLogger(httpReq.Context()).WithError(err).Error("federation.LookupRoomAlias failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go index 807e7b2f2d..bb14ab031c 100644 --- a/federationapi/routing/query_test.go +++ b/federationapi/routing/query_test.go @@ -25,7 +25,6 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" fedAPI "github.com/matrix-org/dendrite/federationapi" - fedclient "github.com/matrix-org/dendrite/federationapi/api" fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/internal/caching" @@ -36,15 +35,16 @@ import ( "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) type fakeFedClient struct { - fedclient.FederationClient + fclient.FederationClient } -func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res fclient.RespDirectory, err error) { +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s spec.ServerName, roomAlias string) (res fclient.RespDirectory, err error) { return } @@ -69,14 +69,14 @@ func TestHandleQueryDirectory(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(routers, cfg, nil, r, keyRing, &fedClient, &userapi, &cfg.MSCs, nil, nil, caching.DisableMetrics) + routing.Setup(routers, cfg, nil, r, keyRing, &fedClient, &userapi, &cfg.MSCs, nil, caching.DisableMetrics) handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + serverName := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) type queryContent struct{} content := queryContent{} err := req.SetContent(content) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index a1f943e776..8865022ff3 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -23,8 +23,6 @@ import ( "github.com/getsentry/sentry-go" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - federationAPI "github.com/matrix-org/dendrite/federationapi/api" fedInternal "github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" @@ -34,6 +32,8 @@ import ( "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -59,10 +59,9 @@ func Setup( rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, - federation federationAPI.FederationClient, + federation fclient.FederationClient, userAPI userapi.FederationUserAPI, mscCfg *config.MSCs, - servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, enableMetrics bool, ) { fedMux := routers.Federation @@ -85,7 +84,7 @@ func Setup( } localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { - return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host)) + return LocalKeys(cfg, spec.ServerName(req.Host)) }) notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { @@ -94,11 +93,11 @@ func Setup( return util.ErrorResponse(err) } var pkReq *gomatrixserverlib.PublicKeyNotaryLookupRequest - serverName := gomatrixserverlib.ServerName(vars["serverName"]) + serverName := spec.ServerName(vars["serverName"]) keyID := gomatrixserverlib.KeyID(vars["keyID"]) if serverName != "" && keyID != "" { pkReq = &gomatrixserverlib.PublicKeyNotaryLookupRequest{ - ServerKeys: map[gomatrixserverlib.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{ + ServerKeys: map[spec.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{ serverName: { keyID: gomatrixserverlib.PublicKeyNotaryQueryCriteria{}, }, @@ -136,25 +135,33 @@ func Setup( mu := internal.NewMutexByRoom() v1fedmux.Handle("/send/{txnID}", MakeFedAPI( "federation_send", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, userAPI, keys, federation, mu, servers, producer, + cfg, rsAPI, userAPI, keys, federation, mu, producer, ) }, )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), + } + } + + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), } } return InviteV1( - httpReq, request, vars["roomID"], vars["eventID"], + httpReq, request, *roomID, vars["eventID"], cfg, rsAPI, keys, ) }, @@ -162,15 +169,23 @@ func Setup( v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), + } + } + + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), } } return InviteV2( - httpReq, request, vars["roomID"], vars["eventID"], + httpReq, request, *roomID, vars["eventID"], cfg, rsAPI, keys, ) }, @@ -184,7 +199,7 @@ func Setup( v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( "exchange_third_party_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( httpReq, request, vars["roomID"], rsAPI, cfg, federation, ) @@ -193,7 +208,7 @@ func Setup( v1fedmux.Handle("/event/{eventID}", MakeFedAPI( "federation_get_event", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, ) @@ -202,11 +217,11 @@ func Setup( v1fedmux.Handle("/state/{roomID}", MakeFedAPI( "federation_get_state", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetState( @@ -217,11 +232,11 @@ func Setup( v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( "federation_get_state_ids", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetStateIDs( @@ -232,11 +247,11 @@ func Setup( v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( "federation_get_event_auth", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetEventAuth( @@ -247,7 +262,7 @@ func Setup( v1fedmux.Handle("/query/directory", MakeFedAPI( "federation_query_room_alias", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( httpReq, federation, cfg, rsAPI, fsAPI, ) @@ -256,7 +271,7 @@ func Setup( v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( httpReq, userAPI, cfg, ) @@ -265,7 +280,7 @@ func Setup( v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( httpReq, userAPI, vars["userID"], ) @@ -275,11 +290,11 @@ func Setup( if mscCfg.Enabled("msc2444") { v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( "federation_peek", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -306,15 +321,13 @@ func Setup( v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( "federation_make_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] - userID := vars["userID"] queryVars := httpReq.URL.Query() remoteVersions := []gomatrixserverlib.RoomVersion{} if vers, ok := queryVars["ver"]; ok { @@ -329,32 +342,56 @@ func Setup( // https://matrix.org/docs/spec/server_server/r0.1.3#get-matrix-federation-v1-make-join-roomid-userid remoteVersions = append(remoteVersions, gomatrixserverlib.RoomVersionV1) } + + userID, err := spec.NewUserID(vars["userID"], true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid UserID"), + } + } + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + + logrus.Debugf("Processing make_join for user %s, room %s", userID.String(), roomID.String()) return MakeJoin( - httpReq, request, cfg, rsAPI, roomID, userID, remoteVersions, + httpReq, request, cfg, rsAPI, *roomID, *userID, remoteVersions, ) }, )).Methods(http.MethodGet) v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + res := SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) // not all responses get wrapped in [code, body] var body interface{} body = []interface{}{ res.Code, res.JSON, } - jerr, ok := res.JSON.(*jsonerror.MatrixError) + jerr, ok := res.JSON.(spec.MatrixError) if ok { body = jerr } @@ -369,45 +406,64 @@ func Setup( v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + return SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) }, )).Methods(http.MethodPut) - v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( + v1fedmux.Handle("/make_leave/{roomID}/{userID}", MakeFedAPI( "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), + } + } + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + userID, err := spec.NewUserID(vars["userID"], true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid UserID"), } } - roomID := vars["roomID"] - eventID := vars["eventID"] return MakeLeave( - httpReq, request, cfg, rsAPI, roomID, eventID, + httpReq, request, cfg, rsAPI, *roomID, *userID, ) }, )).Methods(http.MethodGet) v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -420,7 +476,7 @@ func Setup( body = []interface{}{ res.Code, res.JSON, } - jerr, ok := res.JSON.(*jsonerror.MatrixError) + jerr, ok := res.JSON.(spec.MatrixError) if ok { body = jerr } @@ -435,11 +491,11 @@ func Setup( v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } roomID := vars["roomID"] @@ -459,11 +515,11 @@ func Setup( v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( "federation_get_missing_events", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return GetMissingEvents(httpReq, request, rsAPI, vars["roomID"]) @@ -472,11 +528,11 @@ func Setup( v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( "federation_backfill", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Forbidden by server ACLs"), + JSON: spec.Forbidden("Forbidden by server ACLs"), } } return Backfill(httpReq, request, rsAPI, vars["roomID"], cfg) @@ -491,14 +547,14 @@ func Setup( v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, - func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + func(httpReq *http.Request, request *fclient.FederationRequest, vars map[string]string) util.JSONResponse { return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) @@ -528,7 +584,7 @@ func ErrorIfLocalServerNotInRoom( if !joinedRes.IsInRoom { return &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("This server is not joined to room %s", roomID)), + JSON: spec.NotFound(fmt.Sprintf("This server is not joined to room %s", roomID)), } } return nil @@ -536,14 +592,14 @@ func ErrorIfLocalServerNotInRoom( // MakeFedAPI makes an http.Handler that checks matrix federation authentication. func MakeFedAPI( - metricsName string, serverName gomatrixserverlib.ServerName, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + metricsName string, serverName spec.ServerName, + isLocalServerName func(spec.ServerName) bool, keyRing gomatrixserverlib.JSONVerifier, wakeup *FederationWakeups, - f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, + f func(*http.Request, *fclient.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + fedReq, errResp := fclient.VerifyHTTPRequest( req, time.Now(), serverName, isLocalServerName, keyRing, ) if fedReq == nil { @@ -567,7 +623,7 @@ func MakeFedAPI( go wakeup.Wakeup(req.Context(), fedReq.Origin()) vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { - return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + return util.MatrixErrorResponse(400, string(spec.ErrorUnrecognized), "badly encoded query params") } jsonRes := f(req, fedReq, vars) @@ -586,7 +642,7 @@ type FederationWakeups struct { origins sync.Map } -func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib.ServerName) { +func (f *FederationWakeups) Wakeup(ctx context.Context, origin spec.ServerName) { key, keyok := f.origins.Load(origin) if keyok { lastTime, ok := key.(time.Time) @@ -594,6 +650,6 @@ func (f *FederationWakeups) Wakeup(ctx context.Context, origin gomatrixserverlib return } } - f.FsAPI.MarkServersAlive([]gomatrixserverlib.ServerName{origin}) + f.FsAPI.MarkServersAlive([]spec.ServerName{origin}) f.origins.Store(origin, time.Now()) } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 82651719f7..9666945414 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -22,19 +22,19 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) const ( - // Event was passed to the roomserver + // Event was passed to the Roomserver MetricsOutcomeOK = "ok" // Event failed to be processed MetricsOutcomeFail = "fail" @@ -55,15 +55,14 @@ var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} func Send( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keyAPI userAPI.FederationUserAPI, keys gomatrixserverlib.JSONVerifier, - federation federationAPI.FederationClient, + federation fclient.FederationClient, mu *internal.MutexByRoom, - servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, ) util.JSONResponse { // First we should check if this origin has already submitted this @@ -105,7 +104,7 @@ func Send( if err := json.Unmarshal(request.Content(), &txnEvents); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. @@ -113,7 +112,7 @@ func Send( if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + JSON: spec.BadJSON("max 50 pdus / 100 edus"), } } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 28fa6d6d22..f629479dac 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -32,12 +32,14 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testOrigin = spec.ServerName("kaer.morhen") ) type sendContent struct { @@ -64,14 +66,14 @@ func TestHandleSend(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(routers, cfg, nil, r, keyRing, nil, nil, &cfg.MSCs, nil, nil, caching.DisableMetrics) + routing.Setup(routers, cfg, nil, r, keyRing, nil, nil, &cfg.MSCs, nil, caching.DisableMetrics) handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) - serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + serverName := spec.ServerName(hex.EncodeToString(pk)) + req := fclient.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") content := sendContent{} err := req.SetContent(content) if err != nil { diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 1152c09323..11ad1ebfc8 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -17,17 +17,17 @@ import ( "net/http" "net/url" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) // GetState returns state events & auth events for the roomID, eventID func GetState( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { @@ -42,15 +42,15 @@ func GetState( } return util.JSONResponse{Code: http.StatusOK, JSON: &fclient.RespState{ - AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(authChain), - StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateEvents), + AuthEvents: types.NewEventJSONsFromHeaderedEvents(authChain), + StateEvents: types.NewEventJSONsFromHeaderedEvents(stateEvents), }} } // GetStateIDs returns state event IDs & auth event IDs for the roomID, eventID func GetStateIDs( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, ) util.JSONResponse { @@ -75,7 +75,7 @@ func GetStateIDs( } func parseEventIDParam( - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, ) (eventID string, resErr *util.JSONResponse) { URL, err := url.Parse(request.RequestURI()) if err != nil { @@ -88,7 +88,7 @@ func parseEventIDParam( if eventID == "" { resErr = &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("event_id missing"), + JSON: spec.MissingParam("event_id missing"), } } @@ -97,11 +97,11 @@ func parseEventIDParam( func getState( ctx context.Context, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, rsAPI api.FederationRoomserverAPI, roomID string, eventID string, -) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { +) (stateEvents, authEvents []*types.HeaderedEvent, errRes *util.JSONResponse) { // If we don't think we belong to this room then don't waste the effort // responding to expensive requests for it. if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil { @@ -114,9 +114,9 @@ func getState( } if event.RoomID() != roomID { - return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")} } - resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID()) if resErr != nil { return nil, nil, resErr } @@ -140,24 +140,24 @@ func getState( case !response.RoomExists: return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Room not found"), + JSON: spec.NotFound("Room not found"), } case !response.StateKnown: return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("State not known"), + JSON: spec.NotFound("State not known"), } case response.IsRejected: return nil, nil, &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Event not found"), + JSON: spec.NotFound("Event not found"), } } return response.StateEvents, response.AuthChainEvents, nil } -func getIDsFromEvent(events []*gomatrixserverlib.HeaderedEvent) []string { +func getIDsFromEvent(events []*types.HeaderedEvent) []string { IDs := make([]string, len(events)) for i := range events { IDs[i] = events[i].EventID() diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 048183ad14..42ba8bfe52 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -22,15 +22,14 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" ) @@ -58,7 +57,7 @@ var ( func CreateInvitesFrom3PIDInvites( req *http.Request, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - federation federationAPI.FederationClient, + federation fclient.FederationClient, userAPI userapi.FederationUserAPI, ) util.JSONResponse { var body invites @@ -66,14 +65,13 @@ func CreateInvitesFrom3PIDInvites( return *reqErr } - evs := []*gomatrixserverlib.HeaderedEvent{} + evs := []*types.HeaderedEvent{} for _, inv := range body.Invites { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(req.Context(), &verReq, &verRes); err != nil { + _, err := rsAPI.QueryRoomVersionForRoom(req.Context(), inv.RoomID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } @@ -82,10 +80,13 @@ func CreateInvitesFrom3PIDInvites( ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("createInviteFrom3PIDInvite failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if event != nil { - evs = append(evs, event.Headered(verRes.RoomVersion)) + evs = append(evs, &types.HeaderedEvent{PDU: event}) } } @@ -102,7 +103,10 @@ func CreateInvitesFrom3PIDInvites( false, ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -114,98 +118,126 @@ func CreateInvitesFrom3PIDInvites( // ExchangeThirdPartyInvite implements PUT /_matrix/federation/v1/exchange_third_party_invite/{roomID} func ExchangeThirdPartyInvite( httpReq *http.Request, - request *gomatrixserverlib.FederationRequest, + request *fclient.FederationRequest, roomID string, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - federation federationAPI.FederationClient, + federation fclient.FederationClient, ) util.JSONResponse { - var builder gomatrixserverlib.EventBuilder - if err := json.Unmarshal(request.Content(), &builder); err != nil { + var proto gomatrixserverlib.ProtoEvent + if err := json.Unmarshal(request.Content(), &proto); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // Check that the room ID is correct. - if builder.RoomID != roomID { + if proto.RoomID != roomID { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), + JSON: spec.BadJSON("The room ID in the request path must match the room ID in the invite event JSON"), } } - _, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender) + validRoomID, err := spec.NewRoomID(roomID) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()), + JSON: spec.BadJSON("Invalid room ID"), } } + userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID)) + if err != nil || userID == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Invalid sender ID"), + } + } + senderDomain := userID.Domain() // Check that the state key is correct. - _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) - if err != nil { + targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey)) + if err != nil || targetUserID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event's state key isn't a Matrix user ID"), + JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"), } } + targetDomain := targetUserID.Domain() // Check that the target user is from the requesting homeserver. if targetDomain != request.Origin() { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The event's state key doesn't have the same domain as the request's origin"), + JSON: spec.BadJSON("The event's state key doesn't have the same domain as the request's origin"), } } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err = rsAPI.QueryRoomVersionForRoom(httpReq.Context(), &verReq, &verRes); err != nil { + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.UnsupportedRoomVersion(err.Error()), + JSON: spec.UnsupportedRoomVersion(err.Error()), } } // Auth and build the event from what the remote server sent us - event, err := buildMembershipEvent(httpReq.Context(), &builder, rsAPI, cfg) + event, err := buildMembershipEvent(httpReq.Context(), &proto, rsAPI, cfg) if err == errNotInRoom { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown room " + roomID), + JSON: spec.NotFound("Unknown room " + roomID), } } else if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Ask the requesting server to sign the newly created event so we know it // acknowledged it - inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil) + inviteReq, err := fclient.NewInviteV2Request(event, nil) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - inviteEvent, err := signedEvent.Event.UntrustedEvent(verRes.RoomVersion) + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Errorf("unknown room version: %s", roomVersion) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + inviteEvent, err := verImpl.NewEventFromUntrustedJSON(signedEvent.Event) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - // Send the event to the roomserver + // Send the event to the Roomserver if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - inviteEvent.Headered(verRes.RoomVersion), + []*types.HeaderedEvent{ + {PDU: inviteEvent}, }, request.Destination(), request.Origin(), @@ -214,7 +246,10 @@ func ExchangeThirdPartyInvite( false, ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ @@ -230,15 +265,9 @@ func ExchangeThirdPartyInvite( func createInviteFrom3PIDInvite( ctx context.Context, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - inv invite, federation federationAPI.FederationClient, + inv invite, federation fclient.FederationClient, userAPI userapi.FederationUserAPI, -) (*gomatrixserverlib.Event, error) { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - return nil, err - } - +) (gomatrixserverlib.PDU, error) { _, server, err := gomatrixserverlib.SplitID('@', inv.MXID) if err != nil { return nil, err @@ -249,9 +278,9 @@ func createInviteFrom3PIDInvite( } // Build the event - builder := &gomatrixserverlib.EventBuilder{ + proto := &gomatrixserverlib.ProtoEvent{ Type: "m.room.member", - Sender: inv.Sender, + SenderID: inv.Sender, RoomID: inv.RoomID, StateKey: &inv.MXID, } @@ -264,19 +293,19 @@ func createInviteFrom3PIDInvite( content := gomatrixserverlib.MemberContent{ AvatarURL: profile.AvatarURL, DisplayName: profile.DisplayName, - Membership: gomatrixserverlib.Invite, + Membership: spec.Invite, ThirdPartyInvite: &gomatrixserverlib.MemberThirdPartyInvite{ Signed: inv.Signed, }, } - if err = builder.SetContent(content); err != nil { + if err = proto.SetContent(content); err != nil { return nil, err } - event, err := buildMembershipEvent(ctx, builder, rsAPI, cfg) + event, err := buildMembershipEvent(ctx, proto, rsAPI, cfg) if err == errNotInRoom { - return nil, sendToRemoteServer(ctx, inv, federation, cfg, *builder) + return nil, sendToRemoteServer(ctx, inv, federation, cfg, *proto) } if err != nil { return nil, err @@ -292,10 +321,10 @@ func createInviteFrom3PIDInvite( // Returns an error if something failed during the process. func buildMembershipEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, rsAPI api.FederationRoomserverAPI, + protoEvent *gomatrixserverlib.ProtoEvent, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, -) (*gomatrixserverlib.Event, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) +) (gomatrixserverlib.PDU, error) { + eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(protoEvent) if err != nil { return nil, err } @@ -304,9 +333,9 @@ func buildMembershipEvent( return nil, errors.New("expecting state tuples for event builder, got none") } - // Ask the roomserver for information about this room + // Ask the Roomserver for information about this room queryReq := api.QueryLatestEventsAndStateRequest{ - RoomID: builder.RoomID, + RoomID: protoEvent.RoomID, StateToFetch: eventsNeeded.Tuples(), } var queryRes api.QueryLatestEventsAndStateResponse @@ -320,19 +349,19 @@ func buildMembershipEvent( } // Auth the event locally - builder.Depth = queryRes.Depth - builder.PrevEvents = queryRes.LatestEvents + protoEvent.Depth = queryRes.Depth + protoEvent.PrevEvents = queryRes.LatestEvents authEvents := gomatrixserverlib.NewAuthEvents(nil) for i := range queryRes.StateEvents { - err = authEvents.AddEvent(queryRes.StateEvents[i].Event) + err = authEvents.AddEvent(queryRes.StateEvents[i].PDU) if err != nil { return nil, err } } - if err = fillDisplayName(builder, authEvents); err != nil { + if err = fillDisplayName(protoEvent, authEvents); err != nil { return nil, err } @@ -340,11 +369,17 @@ func buildMembershipEvent( if err != nil { return nil, err } - builder.AuthEvents = refs + protoEvent.AuthEvents = refs + + verImpl, err := gomatrixserverlib.GetRoomVersion(queryRes.RoomVersion) + if err != nil { + return nil, err + } + builder := verImpl.NewEventBuilderFromProtoEvent(protoEvent) event, err := builder.Build( time.Now(), cfg.Matrix.ServerName, cfg.Matrix.KeyID, - cfg.Matrix.PrivateKey, queryRes.RoomVersion, + cfg.Matrix.PrivateKey, ) return event, err @@ -357,10 +392,10 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation federationAPI.FederationClient, cfg *config.FederationAPI, - builder gomatrixserverlib.EventBuilder, + federation fclient.FederationClient, cfg *config.FederationAPI, + proto gomatrixserverlib.ProtoEvent, ) (err error) { - remoteServers := make([]gomatrixserverlib.ServerName, 2) + remoteServers := make([]spec.ServerName, 2) _, remoteServers[0], err = gomatrixserverlib.SplitID('@', inv.Sender) if err != nil { return @@ -373,7 +408,7 @@ func sendToRemoteServer( } for _, server := range remoteServers { - err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder) + err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, proto) if err == nil { return } @@ -394,7 +429,7 @@ func sendToRemoteServer( // found. Returning an error isn't necessary in this case as the event will be // rejected by gomatrixserverlib. func fillDisplayName( - builder *gomatrixserverlib.EventBuilder, authEvents gomatrixserverlib.AuthEvents, + builder *gomatrixserverlib.ProtoEvent, authEvents gomatrixserverlib.AuthEvents, ) error { var content gomatrixserverlib.MemberContent if err := json.Unmarshal(builder.Content, &content); err != nil { diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index e29e3b140d..e5fc4b940f 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -7,11 +7,11 @@ import ( "sync" "time" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "go.uber.org/atomic" "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Statistics contains information about all of the remote federated @@ -19,10 +19,10 @@ import ( // wrapper. type Statistics struct { DB storage.Database - servers map[gomatrixserverlib.ServerName]*ServerStatistics + servers map[spec.ServerName]*ServerStatistics mutex sync.RWMutex - backoffTimers map[gomatrixserverlib.ServerName]*time.Timer + backoffTimers map[spec.ServerName]*time.Timer backoffMutex sync.RWMutex // How many times should we tolerate consecutive failures before we @@ -45,14 +45,14 @@ func NewStatistics( DB: db, FailuresUntilBlacklist: failuresUntilBlacklist, FailuresUntilAssumedOffline: failuresUntilAssumedOffline, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), - servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), + backoffTimers: make(map[spec.ServerName]*time.Timer), + servers: make(map[spec.ServerName]*ServerStatistics), } } // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. -func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { +func (s *Statistics) ForServer(serverName spec.ServerName) *ServerStatistics { // Look up if we have statistics for this server already. s.mutex.RLock() server, found := s.servers[serverName] @@ -63,7 +63,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, - knownRelayServers: []gomatrixserverlib.ServerName{}, + knownRelayServers: []spec.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -104,17 +104,17 @@ const ( // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - assumedOffline atomic.Bool // is the node assumed to be offline - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes + statistics *Statistics // + serverName spec.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes notifierMutex sync.Mutex - knownRelayServers []gomatrixserverlib.ServerName + knownRelayServers []spec.ServerName relayMutex sync.Mutex } @@ -307,15 +307,15 @@ func (s *ServerStatistics) SuccessCount() uint32 { // KnownRelayServers returns the list of relay servers associated with this // server. -func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { +func (s *ServerStatistics) KnownRelayServers() []spec.ServerName { s.relayMutex.Lock() defer s.relayMutex.Unlock() return s.knownRelayServers } -func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { - seenSet := make(map[gomatrixserverlib.ServerName]bool) - uniqueList := []gomatrixserverlib.ServerName{} +func (s *ServerStatistics) AddRelayServers(relayServers []spec.ServerName) { + seenSet := make(map[spec.ServerName]bool) + uniqueList := []spec.ServerName{} for _, srv := range relayServers { if seenSet[srv] { continue diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 183b9aa0c3..a930bc3b08 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -108,10 +108,10 @@ func TestBackoff(t *testing.T) { func TestRelayServersListing(t *testing.T) { stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{statistics: &stats} - server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + server.AddRelayServers([]spec.ServerName{"relay1", "relay1", "relay2"}) relayServers := server.KnownRelayServers() - assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) - server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + assert.Equal(t, []spec.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]spec.ServerName{"relay1", "relay1", "relay2"}) relayServers = server.KnownRelayServers() - assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + assert.Equal(t, []spec.ServerName{"relay1", "relay2"}, relayServers) } diff --git a/federationapi/storage/cache/keydb.go b/federationapi/storage/cache/keydb.go index 2063dfc55f..b53695ca47 100644 --- a/federationapi/storage/cache/keydb.go +++ b/federationapi/storage/cache/keydb.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // A Database implements gomatrixserverlib.KeyDatabase and is used to store @@ -36,7 +37,7 @@ func (d KeyDatabase) FetcherName() string { // FetchKeys implements gomatrixserverlib.KeyDatabase func (d *KeyDatabase) FetchKeys( ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) for req, ts := range requests { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 4f5300af16..5388b4d2be 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -19,9 +19,11 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" + rstypes "github.com/matrix-org/dendrite/roomserver/types" ) type Database interface { @@ -31,57 +33,57 @@ type Database interface { UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) - GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetAllJoinedHosts(ctx context.Context) ([]spec.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]spec.ServerName, error) StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName spec.ServerName, limit int) (pdus map[*receipt.Receipt]*rstypes.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName spec.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanPDUs(ctx context.Context, serverName spec.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName spec.ServerName, receipts []*receipt.Receipt) error - GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetPendingPDUServerNames(ctx context.Context) ([]spec.ServerName, error) + GetPendingEDUServerNames(ctx context.Context) ([]spec.ServerName, error) // these don't have contexts passed in as we want things to happen regardless of the request context - AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error - RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error + AddServerToBlacklist(serverName spec.ServerName) error + RemoveServerFromBlacklist(serverName spec.ServerName) error RemoveAllServersFromBlacklist() error - IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + IsServerBlacklisted(serverName spec.ServerName) (bool, error) // Adds the server to the list of assumed offline servers. // If the server already exists in the table, nothing happens and returns success. - SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + SetServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error // Removes the server from the list of assumed offline servers. // If the server doesn't exist in the table, nothing happens and returns success. - RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + RemoveServerAssumedOffline(ctx context.Context, serverName spec.ServerName) error // Purges all entries from the assumed offline table. RemoveAllServersAssumedOffline(ctx context.Context) error // Gets whether the provided server is present in the table. // If it is present, returns true. If not, returns false. - IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + IsServerAssumedOffline(ctx context.Context, serverName spec.ServerName) (bool, error) - AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) + AddOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + RenewOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + GetOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.OutboundPeek, error) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) - AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error - GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) + AddInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + RenewInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error + GetInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.InboundPeek, error) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) // Update the notary with the given server keys from the given server name. - UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error + UpdateNotaryKeys(ctx context.Context, serverName spec.ServerName, serverKeys gomatrixserverlib.ServerKeys) error // Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs)) // such that the combination of all server keys will include all the `optKeyIDs`. - GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + GetNotaryKeys(ctx context.Context, serverName spec.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error @@ -91,17 +93,17 @@ type Database interface { type P2PDatabase interface { // Stores the given list of servers as relay servers for the provided destination server. // Providing duplicates will only lead to a single entry and won't lead to an error. - P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + P2PAddRelayServersForServer(ctx context.Context, serverName spec.ServerName, relayServers []spec.ServerName) error // Get the list of relay servers associated with the provided destination server. // If no entry exists in the table, an empty list is returned and does not result in an error. - P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + P2PGetRelayServersForServer(ctx context.Context, serverName spec.ServerName) ([]spec.ServerName, error) // Deletes any entries for the provided destination server that match the provided relayServers list. // If any of the provided servers don't match an entry, nothing happens and no error is returned. - P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + P2PRemoveRelayServersForServer(ctx context.Context, serverName spec.ServerName, relayServers []spec.ServerName) error // Deletes all entries for the provided destination server. // If the destination server doesn't exist in the table, nothing happens and no error is returned. - P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName spec.ServerName) error } diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go index 5695d2e549..d8d389d866 100644 --- a/federationapi/storage/postgres/assumed_offline_table.go +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const assumedOfflineSchema = ` @@ -68,7 +68,7 @@ func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, er } func (s *assumedOfflineStatements) InsertAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -76,7 +76,7 @@ func (s *assumedOfflineStatements) InsertAssumedOffline( } func (s *assumedOfflineStatements) SelectAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -91,7 +91,7 @@ func (s *assumedOfflineStatements) SelectAssumedOffline( } func (s *assumedOfflineStatements) DeleteAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/postgres/blacklist_table.go b/federationapi/storage/postgres/blacklist_table.go index 1d931daa3b..48b6d72e1e 100644 --- a/federationapi/storage/postgres/blacklist_table.go +++ b/federationapi/storage/postgres/blacklist_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const blacklistSchema = ` @@ -69,7 +69,7 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { } func (s *blacklistStatements) InsertBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -77,7 +77,7 @@ func (s *blacklistStatements) InsertBlacklist( } func (s *blacklistStatements) SelectBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -92,7 +92,7 @@ func (s *blacklistStatements) SelectBlacklist( } func (s *blacklistStatements) DeleteBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go index 53a7a025e8..cf2d94b205 100644 --- a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go +++ b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go @@ -20,7 +20,7 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { @@ -28,7 +28,7 @@ func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } - _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", spec.AsTimestamp(time.Now().Add(time.Hour*24))) if err != nil { return fmt.Errorf("failed to update queue_edus: %w", err) } diff --git a/federationapi/storage/postgres/inbound_peeks_table.go b/federationapi/storage/postgres/inbound_peeks_table.go index ad2afcb15f..a6fffc0e12 100644 --- a/federationapi/storage/postgres/inbound_peeks_table.go +++ b/federationapi/storage/postgres/inbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const inboundPeeksSchema = ` @@ -86,7 +86,7 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er } func (s *inboundPeeksStatements) InsertInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) @@ -95,7 +95,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek( } func (s *inboundPeeksStatements) RenewInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -103,7 +103,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek( } func (s *inboundPeeksStatements) SelectInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.InboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) inboundPeek := types.InboundPeek{} @@ -152,7 +152,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks( } func (s *inboundPeeksStatements) DeleteInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 8806db550c..2b0aebad1f 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const joinedHostsSchema = ` @@ -105,7 +105,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) @@ -143,20 +143,20 @@ func (s *joinedHostsStatements) SelectJoinedHosts( func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -164,7 +164,7 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( func (s *joinedHostsStatements) SelectJoinedHostsForRooms( ctx context.Context, roomIDs []string, excludingBlacklisted bool, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := s.selectJoinedHostsForRoomsStmt if excludingBlacklisted { stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt @@ -175,13 +175,13 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( } defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -204,7 +204,7 @@ func joinedHostsFromStmt( } result = append(result, types.JoinedHost{ MemberEventID: eventID, - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), }) } diff --git a/federationapi/storage/postgres/notary_server_keys_json_table.go b/federationapi/storage/postgres/notary_server_keys_json_table.go index 9fc93a612d..af98a0d4e0 100644 --- a/federationapi/storage/postgres/notary_server_keys_json_table.go +++ b/federationapi/storage/postgres/notary_server_keys_json_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysJSONSchema = ` @@ -57,7 +58,7 @@ func NewPostgresNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements } func (s *notaryServerKeysStatements) InsertJSONResponse( - ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName spec.ServerName, validUntil spec.Timestamp, ) (tables.NotaryID, error) { var notaryID tables.NotaryID return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(¬aryID) diff --git a/federationapi/storage/postgres/notary_server_keys_metadata_table.go b/federationapi/storage/postgres/notary_server_keys_metadata_table.go index 6d38ccab5f..7a1ec41220 100644 --- a/federationapi/storage/postgres/notary_server_keys_metadata_table.go +++ b/federationapi/storage/postgres/notary_server_keys_metadata_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysMetadataSchema = ` @@ -102,12 +103,12 @@ func NewPostgresNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMe } func (s *notaryServerKeysMetadataStatements) UpsertKey( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil spec.Timestamp, ) (tables.NotaryID, error) { notaryID := newNotaryID // see if the existing notary ID a) exists, b) has a longer valid_until var existingNotaryID tables.NotaryID - var existingValidUntil gomatrixserverlib.Timestamp + var existingValidUntil spec.Timestamp if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil { if err != sql.ErrNoRows { return 0, err @@ -122,7 +123,7 @@ func (s *notaryServerKeysMetadataStatements) UpsertKey( return notaryID, err } -func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { +func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { var rows *sql.Rows var err error if len(keyIDs) == 0 { diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index 5df6843183..bd2b10e674 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const outboundPeeksSchema = ` @@ -85,7 +85,7 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err } func (s *outboundPeeksStatements) InsertOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) @@ -94,7 +94,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( } func (s *outboundPeeksStatements) RenewOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -102,7 +102,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( } func (s *outboundPeeksStatements) SelectOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) outboundPeek := types.OutboundPeek{} @@ -151,7 +151,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks( } func (s *outboundPeeksStatements) DeleteOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index 8870dc88d6..7c57ed0cc4 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -19,11 +19,11 @@ import ( "database/sql" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queueEDUsSchema = ` @@ -121,9 +121,9 @@ func (s *queueEDUsStatements) InsertQueueEDU( ctx context.Context, txn *sql.Tx, eduType string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, - expiresAt gomatrixserverlib.Timestamp, + expiresAt spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -138,7 +138,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( func (s *queueEDUsStatements) DeleteQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteQueueEDUStmt) @@ -148,7 +148,7 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( func (s *queueEDUsStatements) SelectQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) @@ -182,16 +182,16 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } @@ -203,7 +203,7 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) rows, err := stmt.QueryContext(ctx, expiredBefore) @@ -224,7 +224,7 @@ func (s *queueEDUsStatements) SelectExpiredEDUs( func (s *queueEDUsStatements) DeleteExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) _, err := stmt.ExecContext(ctx, expiredBefore) diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index b97be4822c..a767ec41d7 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queuePDUsSchema = ` @@ -91,7 +92,7 @@ func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) @@ -106,7 +107,7 @@ func (s *queuePDUsStatements) InsertQueuePDU( func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt) @@ -131,7 +132,7 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) @@ -154,16 +155,16 @@ func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go index f7267978f1..9e1bc5d404 100644 --- a/federationapi/storage/postgres/relay_servers_table.go +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -21,7 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const relayServersSchema = ` @@ -78,8 +78,8 @@ func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err er func (s *relayServersStatements) InsertRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { for _, relayServer := range relayServers { stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) @@ -93,8 +93,8 @@ func (s *relayServersStatements) InsertRelayServers( func (s *relayServersStatements) SelectRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) rows, err := stmt.QueryContext(ctx, serverName) if err != nil { @@ -102,13 +102,13 @@ func (s *relayServersStatements) SelectRelayServers( } defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var relayServer string if err = rows.Scan(&relayServer); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(relayServer)) + result = append(result, spec.ServerName(relayServer)) } return result, nil } @@ -116,8 +116,8 @@ func (s *relayServersStatements) SelectRelayServers( func (s *relayServersStatements) DeleteRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) @@ -127,7 +127,7 @@ func (s *relayServersStatements) DeleteRelayServers( func (s *relayServersStatements) DeleteAllRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) if _, err := stmt.ExecContext(ctx, serverName); err != nil { diff --git a/federationapi/storage/postgres/server_key_table.go b/federationapi/storage/postgres/server_key_table.go index f393351bb2..c62446da53 100644 --- a/federationapi/storage/postgres/server_key_table.go +++ b/federationapi/storage/postgres/server_key_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const serverSigningKeysSchema = ` @@ -80,7 +81,7 @@ func NewPostgresServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatement func (s *serverSigningKeyStatements) BulkSelectServerKeys( ctx context.Context, txn *sql.Tx, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { var nameAndKeyIDs []string for request := range requests { @@ -103,7 +104,7 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( return nil, err } r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), KeyID: gomatrixserverlib.KeyID(keyID), } vk := gomatrixserverlib.VerifyKey{} @@ -113,8 +114,8 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( } results[r] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + ValidUntilTS: spec.Timestamp(validUntilTS), + ExpiredTS: spec.Timestamp(expiredTS), } } return results, rows.Err() diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 468567cf09..30665bc56b 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Database stores information needed by the federation sender @@ -36,7 +36,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6769637bcf..8c73967c6f 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -26,11 +26,12 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database struct { DB *sql.DB - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool Cache caching.FederationCache Writer sqlutil.Writer FederationQueuePDUs tables.FederationQueuePDUs @@ -102,7 +103,7 @@ func (d *Database) GetJoinedHosts( // Returns an error if something goes wrong. func (d *Database) GetAllJoinedHosts( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } @@ -111,7 +112,7 @@ func (d *Database) GetJoinedHostsForRooms( roomIDs []string, excludeSelf, excludeBlacklisted bool, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -148,7 +149,7 @@ func (d *Database) StoreJSON( } func (d *Database) AddServerToBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) @@ -156,7 +157,7 @@ func (d *Database) AddServerToBlacklist( } func (d *Database) RemoveServerFromBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) @@ -170,14 +171,14 @@ func (d *Database) RemoveAllServersFromBlacklist() error { } func (d *Database) IsServerBlacklisted( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } func (d *Database) SetServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) @@ -186,7 +187,7 @@ func (d *Database) SetServerAssumedOffline( func (d *Database) RemoveServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) @@ -203,15 +204,15 @@ func (d *Database) RemoveAllServersAssumedOffline( func (d *Database) IsServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) } func (d *Database) P2PAddRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) @@ -220,15 +221,15 @@ func (d *Database) P2PAddRelayServersForServer( func (d *Database) P2PGetRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) } func (d *Database) P2PRemoveRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) @@ -237,7 +238,7 @@ func (d *Database) P2PRemoveRelayServersForServer( func (d *Database) P2PRemoveAllRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) @@ -246,7 +247,7 @@ func (d *Database) P2PRemoveAllRelayServersForServer( func (d *Database) AddOutboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -258,7 +259,7 @@ func (d *Database) AddOutboundPeek( func (d *Database) RenewOutboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -270,7 +271,7 @@ func (d *Database) RenewOutboundPeek( func (d *Database) GetOutboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { @@ -286,7 +287,7 @@ func (d *Database) GetOutboundPeeks( func (d *Database) AddInboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -298,7 +299,7 @@ func (d *Database) AddInboundPeek( func (d *Database) RenewInboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, renewalInterval int64, @@ -310,7 +311,7 @@ func (d *Database) RenewInboundPeek( func (d *Database) GetInboundPeek( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, roomID string, peekID string, ) (*types.InboundPeek, error) { @@ -326,7 +327,7 @@ func (d *Database) GetInboundPeeks( func (d *Database) UpdateNotaryKeys( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, serverKeys gomatrixserverlib.ServerKeys, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -337,7 +338,7 @@ func (d *Database) UpdateNotaryKeys( // https://spec.matrix.org/unstable/server-server-api/#querying-keys-through-another-server weekIntoFuture := time.Now().Add(7 * 24 * time.Hour) if weekIntoFuture.Before(validUntil.Time()) { - validUntil = gomatrixserverlib.AsTimestamp(weekIntoFuture) + validUntil = spec.AsTimestamp(weekIntoFuture) } notaryID, err := d.NotaryServerKeysJSON.InsertJSONResponse(ctx, txn, serverKeys, serverName, validUntil) if err != nil { @@ -364,7 +365,7 @@ func (d *Database) UpdateNotaryKeys( func (d *Database) GetNotaryKeys( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index cff1ade6f2..e8d1d37336 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // defaultExpiry for EDUs if not listed below @@ -32,8 +33,8 @@ var defaultExpiry = time.Hour * 24 // defaultExpireEDUTypes contains EDUs which can/should be expired after a given time // if the target server isn't reachable for some reason. var defaultExpireEDUTypes = map[string]time.Duration{ - gomatrixserverlib.MTyping: time.Minute, - gomatrixserverlib.MPresence: time.Minute * 10, + spec.MTyping: time.Minute, + spec.MPresence: time.Minute * 10, } // AssociateEDUWithDestination creates an association that the @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ // to which servers. func (d *Database) AssociateEDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -49,14 +50,14 @@ func (d *Database) AssociateEDUWithDestinations( if expireEDUTypes == nil { expireEDUTypes = defaultExpireEDUTypes } - expiresAt := gomatrixserverlib.AsTimestamp(time.Now().Add(defaultExpiry)) + expiresAt := spec.AsTimestamp(time.Now().Add(defaultExpiry)) if duration, ok := expireEDUTypes[eduType]; ok { // Keep EDUs for at least x minutes before deleting them - expiresAt = gomatrixserverlib.AsTimestamp(time.Now().Add(duration)) + expiresAt = spec.AsTimestamp(time.Now().Add(duration)) } // We forcibly set m.direct_to_device and m.device_list_update events // to 0, as we always want them to be delivered. (required for E2EE) - if eduType == gomatrixserverlib.MDirectToDevice || eduType == gomatrixserverlib.MDeviceListUpdate { + if eduType == spec.MDirectToDevice || eduType == spec.MDeviceListUpdate { expiresAt = 0 } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -79,7 +80,7 @@ func (d *Database) AssociateEDUWithDestinations( // the next pending transaction, up to the limit specified. func (d *Database) GetPendingEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ( edus map[*receipt.Receipt]*gomatrixserverlib.EDU, @@ -126,7 +127,7 @@ func (d *Database) GetPendingEDUs( // transaction was sent successfully. func (d *Database) CleanEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { @@ -169,7 +170,7 @@ func (d *Database) CleanEDUs( // waiting to be sent. func (d *Database) GetPendingEDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } @@ -177,7 +178,7 @@ func (d *Database) GetPendingEDUServerNames( func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { var jsonNIDs []int64 err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) (err error) { - expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) + expiredBefore := spec.AsTimestamp(time.Now()) jsonNIDs, err = d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) if err != nil { return err diff --git a/federationapi/storage/shared/storage_keys.go b/federationapi/storage/shared/storage_keys.go index 3222b12240..580cf1d847 100644 --- a/federationapi/storage/shared/storage_keys.go +++ b/federationapi/storage/shared/storage_keys.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // FetcherName implements KeyFetcher @@ -30,7 +31,7 @@ func (d Database) FetcherName() string { // FetchKeys implements gomatrixserverlib.KeyDatabase func (d *Database) FetchKeys( ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return d.ServerSigningKeys.BulkSelectServerKeys(ctx, nil, requests) } diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 854e005536..5fabfbf204 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -22,7 +22,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) // AssociatePDUWithDestination creates an association that the @@ -30,7 +31,7 @@ import ( // to which servers. func (d *Database) AssociatePDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -52,10 +53,10 @@ func (d *Database) AssociatePDUWithDestinations( // the next pending transaction, up to the limit specified. func (d *Database) GetPendingPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ( - events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*types.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -63,7 +64,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*types.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -86,7 +87,7 @@ func (d *Database) GetPendingPDUs( } for nid, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent + var event types.HeaderedEvent if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } @@ -105,7 +106,7 @@ func (d *Database) GetPendingPDUs( // successfully. func (d *Database) CleanPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { @@ -148,6 +149,6 @@ func (d *Database) CleanPDUs( // waiting to be sent. func (d *Database) GetPendingPDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { return d.FederationQueuePDUs.SelectQueuePDUServerNames(ctx, nil) } diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go index ff2afb4da1..f8de7f0c57 100644 --- a/federationapi/storage/sqlite3/assumed_offline_table.go +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const assumedOfflineSchema = ` @@ -68,7 +68,7 @@ func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err } func (s *assumedOfflineStatements) InsertAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -76,7 +76,7 @@ func (s *assumedOfflineStatements) InsertAssumedOffline( } func (s *assumedOfflineStatements) SelectAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -91,7 +91,7 @@ func (s *assumedOfflineStatements) SelectAssumedOffline( } func (s *assumedOfflineStatements) DeleteAssumedOffline( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/sqlite3/blacklist_table.go b/federationapi/storage/sqlite3/blacklist_table.go index 5122bff160..2c65c487cb 100644 --- a/federationapi/storage/sqlite3/blacklist_table.go +++ b/federationapi/storage/sqlite3/blacklist_table.go @@ -19,7 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const blacklistSchema = ` @@ -69,7 +69,7 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { } func (s *blacklistStatements) InsertBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) @@ -77,7 +77,7 @@ func (s *blacklistStatements) InsertBlacklist( } func (s *blacklistStatements) SelectBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (bool, error) { stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) res, err := stmt.QueryContext(ctx, serverName) @@ -92,7 +92,7 @@ func (s *blacklistStatements) SelectBlacklist( } func (s *blacklistStatements) DeleteBlacklist( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) _, err := stmt.ExecContext(ctx, serverName) diff --git a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go index c5030163b9..d8be4695ee 100644 --- a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go +++ b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go @@ -20,7 +20,7 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { @@ -52,7 +52,7 @@ INSERT if err != nil { return fmt.Errorf("failed to update queue_edus: %w", err) } - _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", spec.AsTimestamp(time.Now().Add(time.Hour*24))) if err != nil { return fmt.Errorf("failed to update queue_edus: %w", err) } diff --git a/federationapi/storage/sqlite3/inbound_peeks_table.go b/federationapi/storage/sqlite3/inbound_peeks_table.go index 8c35679340..e58d537781 100644 --- a/federationapi/storage/sqlite3/inbound_peeks_table.go +++ b/federationapi/storage/sqlite3/inbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const inboundPeeksSchema = ` @@ -86,7 +86,7 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro } func (s *inboundPeeksStatements) InsertInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) @@ -95,7 +95,7 @@ func (s *inboundPeeksStatements) InsertInboundPeek( } func (s *inboundPeeksStatements) RenewInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -103,7 +103,7 @@ func (s *inboundPeeksStatements) RenewInboundPeek( } func (s *inboundPeeksStatements) SelectInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.InboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) inboundPeek := types.InboundPeek{} @@ -152,7 +152,7 @@ func (s *inboundPeeksStatements) SelectInboundPeeks( } func (s *inboundPeeksStatements) DeleteInboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/sqlite3/joined_hosts_table.go b/federationapi/storage/sqlite3/joined_hosts_table.go index 2f0763829b..2412cacdb7 100644 --- a/federationapi/storage/sqlite3/joined_hosts_table.go +++ b/federationapi/storage/sqlite3/joined_hosts_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const joinedHostsSchema = ` @@ -104,7 +104,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) @@ -146,20 +146,20 @@ func (s *joinedHostsStatements) SelectJoinedHosts( func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -167,7 +167,7 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( func (s *joinedHostsStatements) SelectJoinedHostsForRooms( ctx context.Context, roomIDs []string, excludingBlacklisted bool, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i := range roomIDs { iRoomIDs[i] = roomIDs[i] @@ -183,13 +183,13 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( } defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var serverName string if err = rows.Scan(&serverName); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(serverName)) + result = append(result, spec.ServerName(serverName)) } return result, rows.Err() @@ -212,7 +212,7 @@ func joinedHostsFromStmt( } result = append(result, types.JoinedHost{ MemberEventID: eventID, - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), }) } diff --git a/federationapi/storage/sqlite3/notary_server_keys_json_table.go b/federationapi/storage/sqlite3/notary_server_keys_json_table.go index 24875569b8..ad6d1b57fe 100644 --- a/federationapi/storage/sqlite3/notary_server_keys_json_table.go +++ b/federationapi/storage/sqlite3/notary_server_keys_json_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysJSONSchema = ` @@ -56,7 +57,7 @@ func NewSQLiteNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements, } func (s *notaryServerKeysStatements) InsertJSONResponse( - ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName spec.ServerName, validUntil spec.Timestamp, ) (tables.NotaryID, error) { var notaryID tables.NotaryID return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(¬aryID) diff --git a/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go b/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go index 7179eb8d6f..2fd9ef2119 100644 --- a/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go +++ b/federationapi/storage/sqlite3/notary_server_keys_metadata_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const notaryServerKeysMetadataSchema = ` @@ -101,12 +102,12 @@ func NewSQLiteNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMeta } func (s *notaryServerKeysMetadataStatements) UpsertKey( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil spec.Timestamp, ) (tables.NotaryID, error) { notaryID := newNotaryID // see if the existing notary ID a) exists, b) has a longer valid_until var existingNotaryID tables.NotaryID - var existingValidUntil gomatrixserverlib.Timestamp + var existingValidUntil spec.Timestamp if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil { if err != sql.ErrNoRows { return 0, err @@ -121,7 +122,7 @@ func (s *notaryServerKeysMetadataStatements) UpsertKey( return notaryID, err } -func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { +func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { var rows *sql.Rows var err error if len(keyIDs) == 0 { diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index 33f452b688..b6684e9b3b 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const outboundPeeksSchema = ` @@ -85,7 +85,7 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er } func (s *outboundPeeksStatements) InsertOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) @@ -94,7 +94,7 @@ func (s *outboundPeeksStatements) InsertOutboundPeek( } func (s *outboundPeeksStatements) RenewOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) @@ -102,7 +102,7 @@ func (s *outboundPeeksStatements) RenewOutboundPeek( } func (s *outboundPeeksStatements) SelectOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) outboundPeek := types.OutboundPeek{} @@ -151,7 +151,7 @@ func (s *outboundPeeksStatements) SelectOutboundPeeks( } func (s *outboundPeeksStatements) DeleteOutboundPeek( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) return diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index 0dc9143286..f500a63179 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -20,11 +20,10 @@ import ( "fmt" "strings" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queueEDUsSchema = ` @@ -121,9 +120,9 @@ func (s *queueEDUsStatements) InsertQueueEDU( ctx context.Context, txn *sql.Tx, eduType string, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, - expiresAt gomatrixserverlib.Timestamp, + expiresAt spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -138,7 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( func (s *queueEDUsStatements) DeleteQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) @@ -160,7 +159,7 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( func (s *queueEDUsStatements) SelectQueueEDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) @@ -194,16 +193,16 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } @@ -215,7 +214,7 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) rows, err := stmt.QueryContext(ctx, expiredBefore) @@ -236,7 +235,7 @@ func (s *queueEDUsStatements) SelectExpiredEDUs( func (s *queueEDUsStatements) DeleteExpiredEDUs( ctx context.Context, txn *sql.Tx, - expiredBefore gomatrixserverlib.Timestamp, + expiredBefore spec.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) _, err := stmt.ExecContext(ctx, expiredBefore) diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index d8d99f0c08..92075ff903 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const queuePDUsSchema = ` @@ -100,7 +101,7 @@ func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) @@ -115,7 +116,7 @@ func (s *queuePDUsStatements) InsertQueuePDU( func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) @@ -136,7 +137,7 @@ func (s *queuePDUsStatements) DeleteQueuePDUs( } func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) @@ -161,7 +162,7 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) @@ -184,16 +185,16 @@ func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName if err = rows.Scan(&serverName); err != nil { return nil, err } diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go index 27c3cca2ce..36cabeb4d9 100644 --- a/federationapi/storage/sqlite3/relay_servers_table.go +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const relayServersSchema = ` @@ -77,8 +77,8 @@ func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err erro func (s *relayServersStatements) InsertRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { for _, relayServer := range relayServers { stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) @@ -92,8 +92,8 @@ func (s *relayServersStatements) InsertRelayServers( func (s *relayServersStatements) SelectRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) rows, err := stmt.QueryContext(ctx, serverName) if err != nil { @@ -101,13 +101,13 @@ func (s *relayServersStatements) SelectRelayServers( } defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") - var result []gomatrixserverlib.ServerName + var result []spec.ServerName for rows.Next() { var relayServer string if err = rows.Scan(&relayServer); err != nil { return nil, err } - result = append(result, gomatrixserverlib.ServerName(relayServer)) + result = append(result, spec.ServerName(relayServer)) } return result, nil } @@ -115,8 +115,8 @@ func (s *relayServersStatements) SelectRelayServers( func (s *relayServersStatements) DeleteRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) deleteStmt, err := s.db.Prepare(deleteSQL) @@ -138,7 +138,7 @@ func (s *relayServersStatements) DeleteRelayServers( func (s *relayServersStatements) DeleteAllRelayServers( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) if _, err := stmt.ExecContext(ctx, serverName); err != nil { diff --git a/federationapi/storage/sqlite3/server_key_table.go b/federationapi/storage/sqlite3/server_key_table.go index b32ff0926a..f28b899405 100644 --- a/federationapi/storage/sqlite3/server_key_table.go +++ b/federationapi/storage/sqlite3/server_key_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const serverSigningKeysSchema = ` @@ -82,7 +83,7 @@ func NewSQLiteServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatements, func (s *serverSigningKeyStatements) BulkSelectServerKeys( ctx context.Context, txn *sql.Tx, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, + requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { nameAndKeyIDs := make([]string, 0, len(requests)) for request := range requests { @@ -107,7 +108,7 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( return fmt.Errorf("bulkSelectServerKeys: %v", err) } r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), + ServerName: spec.ServerName(serverName), KeyID: gomatrixserverlib.KeyID(keyID), } vk := gomatrixserverlib.VerifyKey{} @@ -117,8 +118,8 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys( } results[r] = gomatrixserverlib.PublicKeyLookupResult{ VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + ValidUntilTS: spec.Timestamp(validUntilTS), + ExpiredTS: spec.Timestamp(expiredTS), } } return nil diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c64c9a4f02..00c8afa059 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // Database stores information needed by the federation sender @@ -34,7 +34,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index 4eb9d2c984..322a6c75bb 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -26,11 +26,11 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // NewDatabase opens a new database -func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 74863c07c3..db71f2c13c 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -12,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" ) @@ -23,7 +24,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewDatabase(ctx, cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) + }, caches, func(server spec.ServerName) bool { return server == "localhost" }) if err != nil { t.Fatalf("NewDatabase returned %s", err) } @@ -34,11 +35,11 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat func TestExpireEDUs(t *testing.T) { var expireEDUTypes = map[string]time.Duration{ - gomatrixserverlib.MReceipt: 0, + spec.MReceipt: 0, } ctx := context.Background() - destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}} + destinations := map[spec.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() @@ -47,7 +48,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, spec.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry @@ -71,7 +72,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, spec.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx) @@ -249,8 +250,8 @@ func TestInboundPeeking(t *testing.T) { } func TestServersAssumedOffline(t *testing.T) { - server1 := gomatrixserverlib.ServerName("server1") - server2 := gomatrixserverlib.ServerName("server2") + server1 := spec.ServerName("server1") + server2 := spec.ServerName("server2") test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) @@ -305,29 +306,29 @@ func TestServersAssumedOffline(t *testing.T) { } func TestRelayServersStored(t *testing.T) { - server := gomatrixserverlib.ServerName("server") - relayServer1 := gomatrixserverlib.ServerName("relayserver1") - relayServer2 := gomatrixserverlib.ServerName("relayserver2") + server := spec.ServerName("server") + relayServer1 := spec.ServerName("relayserver1") + relayServer2 := spec.ServerName("relayserver2") test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() - err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + err := db.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1}) assert.Nil(t, err) relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Equal(t, relayServer1, relayServers[0]) - err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1}) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Zero(t, len(relayServers)) - err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + err = db.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1, relayServer2}) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) diff --git a/federationapi/storage/storage_wasm.go b/federationapi/storage/storage_wasm.go index d1652d7125..e19a45642f 100644 --- a/federationapi/storage/storage_wasm.go +++ b/federationapi/storage/storage_wasm.go @@ -26,7 +26,7 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(spec.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 762504e45d..f8de42da79 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -20,26 +20,27 @@ import ( "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type NotaryID int64 type FederationQueuePDUs interface { - InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error - DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64) error + DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) - SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) + SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) + SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]spec.ServerName, error) } type FederationQueueEDUs interface { - InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt gomatrixserverlib.Timestamp) error - DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error - SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName spec.ServerName, nid int64, expiresAt spec.Timestamp) error + DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error + SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) - SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) - DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error + SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]spec.ServerName, error) + SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore spec.Timestamp) ([]int64, error) + DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore spec.Timestamp) error Prepare() error } @@ -50,10 +51,10 @@ type FederationQueueJSON interface { } type FederationQueueTransactions interface { - InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error - DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error - SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) - SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (int64, error) } type FederationTransactionJSON interface { @@ -63,51 +64,51 @@ type FederationTransactionJSON interface { } type FederationJoinedHosts interface { - InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error + InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName spec.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) - SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error) + SelectAllJoinedHosts(ctx context.Context) ([]spec.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]spec.ServerName, error) } type FederationBlacklist interface { - InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error - SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) - DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) + DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } type FederationAssumedOffline interface { - InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error - SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) - DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error } type FederationRelayServers interface { - InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error - SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) - DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error - DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, relayServers []spec.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) ([]spec.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, relayServers []spec.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) error } type FederationOutboundPeeks interface { - InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - SelectOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (outboundPeek *types.OutboundPeek, err error) + InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + SelectOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (outboundPeek *types.OutboundPeek, err error) SelectOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (outboundPeeks []types.OutboundPeek, err error) - DeleteOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error) + DeleteOutboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (err error) DeleteOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error) } type FederationInboundPeeks interface { - InsertInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - RenewInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) - SelectInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (inboundPeek *types.InboundPeek, err error) + InsertInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + RenewInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) (err error) + SelectInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (inboundPeek *types.InboundPeek, err error) SelectInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (inboundPeeks []types.InboundPeek, err error) - DeleteInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error) + DeleteInboundPeek(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, roomID, peekID string) (err error) DeleteInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error) } @@ -118,22 +119,22 @@ type FederationNotaryServerKeysJSON interface { // "Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. // This is to avoid a situation where an attacker publishes a key which is valid for a significant amount of time // without a way for the homeserver owner to revoke it."" - InsertJSONResponse(ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp) (NotaryID, error) + InsertJSONResponse(ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName spec.ServerName, validUntil spec.Timestamp) (NotaryID, error) } // FederationNotaryServerKeysMetadata persists the metadata for FederationNotaryServerKeysJSON type FederationNotaryServerKeysMetadata interface { // UpsertKey updates or inserts a (server_name, key_id) tuple, pointing it via NotaryID at the the response which has the longest valid_until_ts // `newNotaryID` and `newValidUntil` should be the notary ID / valid_until which has this (server_name, key_id) tuple already, e.g one you just inserted. - UpsertKey(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID NotaryID, newValidUntil gomatrixserverlib.Timestamp) (NotaryID, error) + UpsertKey(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID NotaryID, newValidUntil spec.Timestamp) (NotaryID, error) // SelectKeys returns the signed JSON objects which contain the given key IDs. This will be at most the length of `keyIDs` and at least 1 (assuming // the keys exist in the first place). If `keyIDs` is empty, the signed JSON object with the longest valid_until_ts will be returned. - SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + SelectKeys(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteOldJSONResponses removes all responses which are not referenced in FederationNotaryServerKeysMetadata DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error } type FederationServerSigningKeys interface { - BulkSelectServerKeys(ctx context.Context, txn *sql.Tx, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) + BulkSelectServerKeys(ctx context.Context, txn *sql.Tx, requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) UpsertServerKeys(ctx context.Context, txn *sql.Tx, request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult) error } diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go index b41211551d..6a14e3f16d 100644 --- a/federationapi/storage/tables/relay_servers_table_test.go +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -57,7 +57,7 @@ func mustCreateRelayServersTable( return database, close } -func Equal(a, b []gomatrixserverlib.ServerName) bool { +func Equal(a, b []spec.ServerName) bool { if len(a) != len(b) { return false } @@ -74,7 +74,7 @@ func TestShouldInsertRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + expectedRelayServers := []spec.ServerName{server2, server3} err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) if err != nil { @@ -97,8 +97,8 @@ func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + insertRelayServers := []spec.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []spec.ServerName{server2, server3} err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) if err != nil { @@ -134,8 +134,8 @@ func TestShouldGetRelayServersUnknownDestination(t *testing.T) { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) } - if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { - t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + if !Equal(relayServers, []spec.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []spec.ServerName{}, relayServers) } }) } @@ -145,8 +145,8 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - relayServers1 := []gomatrixserverlib.ServerName{server2, server3} - relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + relayServers1 := []spec.ServerName{server2, server3} + relayServers2 := []spec.ServerName{server1, server3, server4} err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) if err != nil { @@ -157,16 +157,16 @@ func TestShouldDeleteCorrectRelayServers(t *testing.T) { t.Fatalf("Failed inserting transaction: %s", err.Error()) } - err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + err = db.Table.DeleteRelayServers(ctx, nil, server1, []spec.ServerName{server2}) if err != nil { t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) } - err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + err = db.Table.DeleteRelayServers(ctx, nil, server2, []spec.ServerName{server1, server4}) if err != nil { t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) } - expectedRelayServers := []gomatrixserverlib.ServerName{server3} + expectedRelayServers := []spec.ServerName{server3} relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) @@ -189,7 +189,7 @@ func TestShouldDeleteAllRelayServers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRelayServersTable(t, dbType) defer close() - expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + expectedRelayServers := []spec.ServerName{server2, server3} err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) if err != nil { @@ -205,7 +205,7 @@ func TestShouldDeleteAllRelayServers(t *testing.T) { t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) } - expectedRelayServers1 := []gomatrixserverlib.ServerName{} + expectedRelayServers1 := []spec.ServerName{} relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) if err != nil { t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) diff --git a/federationapi/types/types.go b/federationapi/types/types.go index 5821000cc9..20f92e804b 100644 --- a/federationapi/types/types.go +++ b/federationapi/types/types.go @@ -14,9 +14,7 @@ package types -import ( - "github.com/matrix-org/gomatrixserverlib" -) +import "github.com/matrix-org/gomatrixserverlib/spec" const MSigningKeyUpdate = "m.signing_key_update" // TODO: move to gomatrixserverlib @@ -25,10 +23,10 @@ type JoinedHost struct { // The MemberEventID of a m.room.member join event. MemberEventID string // The domain part of the state key of the m.room.member join event - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } -type ServerNames []gomatrixserverlib.ServerName +type ServerNames []spec.ServerName func (s ServerNames) Len() int { return len(s) } func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] } @@ -38,7 +36,7 @@ func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] } type OutboundPeek struct { PeekID string RoomID string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName CreationTimestamp int64 RenewedTimestamp int64 RenewalInterval int64 @@ -48,7 +46,7 @@ type OutboundPeek struct { type InboundPeek struct { PeekID string RoomID string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName CreationTimestamp int64 RenewedTimestamp int64 RenewalInterval int64 @@ -64,7 +62,7 @@ type FederationReceiptData struct { } type ReceiptTS struct { - TS gomatrixserverlib.Timestamp `json:"ts"` + TS spec.Timestamp `json:"ts"` } type Presence struct { diff --git a/go.mod b/go.mod index 790d1ee416..c246cd0436 100644 --- a/go.mod +++ b/go.mod @@ -5,58 +5,61 @@ require ( github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/blevesearch/bleve/v2 v2.3.6 + github.com/blevesearch/bleve/v2 v2.3.8 github.com/codeclysm/extract v2.2.0+incompatible github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.24+incompatible github.com/docker/go-connections v0.4.0 - github.com/getsentry/sentry-go v0.14.0 - github.com/go-ldap/ldap/v3 v3.4.4 - github.com/golang-jwt/jwt/v4 v4.4.1 + github.com/getsentry/sentry-go v0.22.0 github.com/gologme/log v1.3.0 github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.0 github.com/gorilla/mux v1.8.0 github.com/kardianos/minwinsvc v1.0.2 - github.com/lib/pq v1.10.7 + github.com/lib/pq v1.10.9 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f + github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 - github.com/mattn/go-sqlite3 v1.14.16 - github.com/nats-io/nats-server/v2 v2.9.15 - github.com/nats-io/nats.go v1.24.0 + github.com/mattn/go-sqlite3 v1.14.17 + github.com/nats-io/nats-server/v2 v2.9.19 + github.com/nats-io/nats.go v1.27.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/opentracing/opentracing-go v1.2.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.13.0 - github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.1 + github.com/prometheus/client_golang v1.16.0 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.6.0 + golang.org/x/crypto v0.10.0 + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/term v0.5.0 + golang.org/x/sync v0.2.0 + golang.org/x/term v0.9.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 - modernc.org/sqlite v1.19.3 + maunium.net/go/mautrix v0.15.1 + modernc.org/sqlite v1.23.1 ) require ( github.com/MFAshby/stdemuxerhook v1.0.0 + github.com/go-ldap/ldap/v3 v3.4.5 + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/matryer/is v1.4.0 ) require ( - github.com/Azure/go-ntlmssp v0.0.0-20220621081337-cb9428e4ac1e // indirect + github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect github.com/RoaringBitmap/roaring v1.2.3 // indirect @@ -76,60 +79,63 @@ require ( github.com/blevesearch/zapx/v12 v12.3.7 // indirect github.com/blevesearch/zapx/v13 v13.3.7 // indirect github.com/blevesearch/zapx/v14 v14.3.7 // indirect - github.com/blevesearch/zapx/v15 v15.3.8 // indirect - github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/blevesearch/zapx/v15 v15.3.10 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/docker/distribution v2.8.1+incompatible // indirect + github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/dustin/go-humanize v1.0.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect github.com/golang/glog v1.0.0 // indirect - github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/h2non/filetype v1.1.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/juju/errors v1.0.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.16.0 // indirect - github.com/kr/pretty v0.3.1 // indirect - github.com/mattn/go-isatty v0.0.16 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect + github.com/klauspost/compress v1.16.5 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/minio/highwayhash v1.0.2 // indirect github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/mschoch/smat v0.2.0 // indirect - github.com/nats-io/jwt/v2 v2.3.0 // indirect - github.com/nats-io/nkeys v0.3.0 // indirect + github.com/nats-io/jwt/v2 v2.4.1 // indirect + github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.2.0 // indirect - github.com/prometheus/common v0.37.0 // indirect - github.com/prometheus/procfs v0.8.0 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect + github.com/prometheus/client_model v0.3.0 // indirect + github.com/prometheus/common v0.42.0 // indirect + github.com/prometheus/procfs v0.10.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + github.com/rs/zerolog v1.29.1 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect - golang.org/x/mod v0.6.0 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/sys v0.5.0 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/mod v0.8.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.2.0 // indirect - google.golang.org/protobuf v1.28.1 // indirect + golang.org/x/tools v0.6.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/uint128 v1.2.0 // indirect + maunium.net/go/maulogger/v2 v2.4.1 // indirect modernc.org/cc/v3 v3.40.0 // indirect - modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8 // indirect - modernc.org/libc v1.21.4 // indirect + modernc.org/ccgo/v3 v3.16.13 // indirect + modernc.org/libc v1.22.5 // indirect modernc.org/mathutil v1.5.0 // indirect - modernc.org/memory v1.4.0 // indirect + modernc.org/memory v1.5.0 // indirect modernc.org/opt v0.1.3 // indirect modernc.org/strutil v1.1.3 // indirect modernc.org/token v1.0.1 // indirect diff --git a/go.sum b/go.sum index a0a4c63215..9fc5b9427b 100644 --- a/go.sum +++ b/go.sum @@ -1,35 +1,3 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= -cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= -cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= -cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= -cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= -cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= -cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= -cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= -cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= -cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= -cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= -cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= -cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= -cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= -cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= -cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= -cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= -cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2 h1:Usab30pNT2i/vZvpXcN9uOr5IO1RZPcUqoGH0DIAPnU= github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2/go.mod h1:RP72rucOFm5udrnEzTmIWLRVGQiV/fSUAQXJ0RST/nk= @@ -37,9 +5,8 @@ github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 h1:WndgpSW13S32VLQ3 github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/Azure/go-ntlmssp v0.0.0-20220621081337-cb9428e4ac1e h1:NeAW1fUYUEWhft7pkxDf6WoUvEZJ/uOKsvtpjLnn8MU= -github.com/Azure/go-ntlmssp v0.0.0-20220621081337-cb9428e4ac1e/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= @@ -55,11 +22,8 @@ github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJ github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVOVmhWBY= github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 h1:Kk6a4nehpJ3UuJRqlA3JxYxBZEqCeOmATOvrbT4p9RA= +github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/anacrolix/envpprof v0.0.0-20180404065416-323002cec2fa/go.mod h1:KgHhUaQMc8cC0+cEflSgCFNFbKwi5h54gqtVn8yhP7c= github.com/anacrolix/envpprof v1.0.0/go.mod h1:KgHhUaQMc8cC0+cEflSgCFNFbKwi5h54gqtVn8yhP7c= github.com/anacrolix/envpprof v1.1.1 h1:sHQCyj7HtiSfaZAzL2rJrQdyS7odLqlwO6nhk/tG/j8= @@ -71,15 +35,13 @@ github.com/anacrolix/missinggo v1.2.1 h1:0IE3TqX5y5D0IxeMwTyIgqdDew4QrzcXaaEnJQy github.com/anacrolix/missinggo v1.2.1/go.mod h1:J5cMhif8jPmFoC3+Uvob3OXXNIhOUikzMt+uUjeM21Y= github.com/anacrolix/missinggo/perf v1.0.0/go.mod h1:ljAFWkBuzkO12MQclXzZrosP5urunoLS0Cbvb4V0uMQ= github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8= github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/blevesearch/bleve/v2 v2.3.6 h1:NlntUHcV5CSWIhpugx4d/BRMGCiaoI8ZZXrXlahzNq4= -github.com/blevesearch/bleve/v2 v2.3.6/go.mod h1:JM2legf1cKVkdV8Ehu7msKIOKC0McSw0Q16Fmv9vsW4= +github.com/blevesearch/bleve/v2 v2.3.8 h1:IqFyMJ73n4gY8AmVqM8Sa6EtAZ5beE8yramVqCvs2kQ= +github.com/blevesearch/bleve/v2 v2.3.8/go.mod h1:Lh9aZEHrLKxwPnW4z4lsBEGnflZQ1V/aWP/t+htsiDw= github.com/blevesearch/bleve_index_api v1.0.5 h1:Lc986kpC4Z0/n1g3gg8ul7H+lxgOQPcXb9SxvQGu+tw= github.com/blevesearch/bleve_index_api v1.0.5/go.mod h1:YXMDwaXFFXwncRS8UobWs7nvo0DmusriM1nztTlj1ms= github.com/blevesearch/geo v0.1.17 h1:AguzI6/5mHXapzB0gE9IKWo+wWPHZmXZoscHcjFgAFA= @@ -108,23 +70,18 @@ github.com/blevesearch/zapx/v13 v13.3.7 h1:igIQg5eKmjw168I7av0Vtwedf7kHnQro/M+ub github.com/blevesearch/zapx/v13 v13.3.7/go.mod h1:yyrB4kJ0OT75UPZwT/zS+Ru0/jYKorCOOSY5dBzAy+s= github.com/blevesearch/zapx/v14 v14.3.7 h1:gfe+fbWslDWP/evHLtp/GOvmNM3sw1BbqD7LhycBX20= github.com/blevesearch/zapx/v14 v14.3.7/go.mod h1:9J/RbOkqZ1KSjmkOes03AkETX7hrXT0sFMpWH4ewC4w= -github.com/blevesearch/zapx/v15 v15.3.8 h1:q4uMngBHzL1IIhRc8AJUEkj6dGOE3u1l3phLu7hq8uk= -github.com/blevesearch/zapx/v15 v15.3.8/go.mod h1:m7Y6m8soYUvS7MjN9eKlz1xrLCcmqfFadmu7GhWIrLY= +github.com/blevesearch/zapx/v15 v15.3.10 h1:bQ9ZxJCj6rKp873EuVJu2JPxQ+EWQZI1cjJGeroovaQ= +github.com/blevesearch/zapx/v15 v15.3.10/go.mod h1:m7Y6m8soYUvS7MjN9eKlz1xrLCcmqfFadmu7GhWIrLY= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 h1:GKTyiRCL6zVf5wWaqKnf+7Qs6GbEPfd4iMOitWzXJx8= github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8/go.mod h1:spo1JLcs67NmW1aVLEgtA8Yy1elc+X8y5SRW1sFW4Og= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/codeclysm/extract v2.2.0+incompatible h1:q3wyckoA30bhUSiwdQezMqVhwd8+WGE64/GL//LtUhI= github.com/codeclysm/extract v2.2.0+incompatible/go.mod h1:2nhFMPHiU9At61hz+12bfrlpXSUrOnK+wR+KlGO4Uks= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -134,8 +91,8 @@ github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWa github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= -github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= -github.com/docker/distribution v2.8.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= +github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v20.10.24+incompatible h1:Ugvxm7a8+Gz6vqQYQQ2W7GYq5EUPaAiuPgIfVyI3dYE= github.com/docker/docker v20.10.24+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= @@ -144,132 +101,68 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/getsentry/sentry-go v0.14.0 h1:rlOBkuFZRKKdUnKO+0U3JclRDQKlRu5vVQtkWSQvC70= -github.com/getsentry/sentry-go v0.14.0/go.mod h1:RZPJKSw+adu8PBNygiri/A98FqVr2HtRckJk9XVxJ9I= +github.com/getsentry/sentry-go v0.22.0 h1:XNX9zKbv7baSEI65l+H1GEJgSeIC1c7EN5kluWaP6dM= +github.com/getsentry/sentry-go v0.22.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE= github.com/glycerine/goconvey v0.0.0-20180728074245-46e3a41ad493/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24= github.com/go-asn1-ber/asn1-ber v1.5.4 h1:vXT6d/FNDiELJnLb6hGNa309LMsrCoYFvpwHDF0+Y1A= github.com/go-asn1-ber/asn1-ber v1.5.4/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= -github.com/go-ldap/ldap/v3 v3.4.4 h1:qPjipEpt+qDa6SI/h1fzuGWoRUY+qqQ9sOZq67/PYUs= -github.com/go-ldap/ldap/v3 v3.4.4/go.mod h1:fe1MsuN5eJJ1FeLT/LEBVdWfNWKh459R7aXgXtJC+aI= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/go-ldap/ldap/v3 v3.4.5 h1:ekEKmaDrpvR2yf5Nc/DClsGG9lAmdDixe44mLzlW5r8= +github.com/go-ldap/ldap/v3 v3.4.5/go.mod h1:bMGIq3AGbytbaMwf8wdv5Phdxz0FWHTIYMSzyrYgnQs= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v4 v4.4.1 h1:pC5DB52sCeK48Wlb9oPcdhnjkz1TKt1D/P7WKJ0kUcQ= -github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 h1:gtexQ/VGyN+VVFRXSFiguSNcXmS6rkKT+X7FdIrTtfo= github.com/golang/geo v0.0.0-20210211234256-740aa86cb551/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/huandu/xstrings v1.0.0 h1:pO2K/gKgKaat5LdpAhxhluX2GPQMaI3W5FUz/I/UnWk= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= -github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/juju/errors v1.0.0 h1:yiq7kjCLll1BiaRuNY53MGI0+EQ3rF6GB+wvboZDefM= github.com/juju/errors v1.0.0/go.mod h1:B5x9thDqx0wIMH3+aLIMP9HjItInYWObRovoCFM5Qe8= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kardianos/minwinsvc v1.0.2 h1:JmZKFJQrmTGa/WiW+vkJXKmfzdjabuEW4Tirj5lLdR0= github.com/kardianos/minwinsvc v1.0.2/go.mod h1:LUZNYhNmxujx2tR7FbdxqYJ9XDDoCd3MQcl1o//FWl4= @@ -277,37 +170,37 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw1HU4= -github.com/klauspost/compress v1.16.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= +github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= -github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM= github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f h1:D7IgZA2DxBroqCTxo2uXEmjj8eCI1OzqqKRE4SAgmBU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230405171344-5f597d85ba4f/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093 h1:FHd3SYhU2ZxZhkssZ/7ms5+M2j+g94lYp8ztvA1E6tA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= -github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= +github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= @@ -316,8 +209,6 @@ github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae/go.mod h1:E2VnQOmVuvZB6U github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= @@ -325,16 +216,14 @@ github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7P github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg= github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM= github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= -github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.15 h1:MuwEJheIwpvFgqvbs20W8Ish2azcygjf4Z0liVu2I4c= -github.com/nats-io/nats-server/v2 v2.9.15/go.mod h1:QlCTy115fqpx4KSOPFIxSV7DdI6OxtZsGOL1JLdeRlE= -github.com/nats-io/nats.go v1.24.0 h1:CRiD8L5GOQu/DcfkmgBcTTIQORMwizF+rPk6T0RaHVQ= -github.com/nats-io/nats.go v1.24.0/go.mod h1:dVQF+BK3SzUZpwyzHedXsvH3EO38aVKuOPkkHlv5hXA= -github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= -github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= +github.com/nats-io/jwt/v2 v2.4.1 h1:Y35W1dgbbz2SQUYDPCaclXcuqleVmpbRa7646Jf2EX4= +github.com/nats-io/jwt/v2 v2.4.1/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= +github.com/nats-io/nats-server/v2 v2.9.19 h1:OF9jSKZGo425C/FcVVIvNgpd36CUe7aVTTXEZRJk6kA= +github.com/nats-io/nats-server/v2 v2.9.19/go.mod h1:aTb/xtLCGKhfTFLxP591CMWfkdgBmcUUSkiSOe5A3gw= +github.com/nats-io/nats.go v1.27.0 h1:3o9fsPhmoKm+yK7rekH2GtWoE+D9jFbw8N3/ayI1C00= +github.com/nats-io/nats.go v1.27.0/go.mod h1:XpbWUlOElGwTYbMR7imivs7jJj9GtK7ypv321Wp6pjc= +github.com/nats-io/nkeys v0.4.4 h1:xvBJ8d69TznjcQl9t6//Q5xXuVhyYiSos6RPtvQNTwA= +github.com/nats-io/nkeys v0.4.4/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ= @@ -352,56 +241,35 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= -github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= -github.com/prometheus/client_golang v1.13.0 h1:b71QUfeo5M8gq2+evJdTPfZhYMAU0uKPkyPJ7TPsloU= -github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= -github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= -github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= -github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE= -github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= -github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= -github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= +github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8= +github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc= +github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4= +github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= +github.com/prometheus/common v0.42.0 h1:EKsfXEYo4JpWMHH5cg+KOUWeuJSov1Id8zGR8eeI1YM= +github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc= +github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= +github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa h1:tEkEyxYeZ43TR55QU/hsIt9aRGBxbgGuz9CGykjvogY= -github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= +github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -411,10 +279,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -433,247 +300,113 @@ github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6 github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/yggdrasil-network/yggdrasil-go v0.4.6 h1:GALUDV9QPz/5FVkbazpkTc9EABHufA556JwUJZr41j4= github.com/yggdrasil-network/yggdrasil-go v0.4.6/go.mod h1:PBMoAOvQjA9geNEeGyMXA9QgCS6Bu+9V+1VkWM84wpw= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI= golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= -golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e h1:zSgtO19fpg781xknwqiQPmOHaASr6E7ZVlTseLd9Fx4= golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= -golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190624222133-a101b041ded4/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= -golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= -golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= -golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -682,100 +415,20 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= -google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= -google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= -google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= -google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/h2non/bimg.v1 v1.1.9 h1:wZIUbeOnwr37Ta4aofhIv8OI8v4ujpjXC9mXnAGpQjM= gopkg.in/h2non/bimg.v1 v1.1.9/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -784,38 +437,32 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= +maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= +maunium.net/go/mautrix v0.15.1 h1:pmCtMjYRpd83+2UL+KTRFYQo5to0373yulimvLK+1k0= +maunium.net/go/mautrix v0.15.1/go.mod h1:icQIrvz2NldkRLTuzSGzmaeuMUmw+fzO7UVycPeauN8= modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= -modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8 h1:0+dsXf0zeLx9ixj4nilg6jKe5Bg1ilzBwSFq4kJmIUc= -modernc.org/ccgo/v3 v3.16.13-0.20221017192402-261537637ce8/go.mod h1:fUB3Vn0nVPReA+7IG7yZDfjv1TMWjhQP8gCxrFAtL5g= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= -modernc.org/libc v1.21.4 h1:CzTlumWeIbPV5/HVIMzYHNPCRP8uiU/CWiN2gtd/Qu8= -modernc.org/libc v1.21.4/go.mod h1:przBsL5RDOZajTVslkugzLBj1evTue36jEomFQOoYuI= +modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE= +modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY= modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= -modernc.org/memory v1.4.0 h1:crykUfNSnMAXaOJnnxcSzbUGMqkLWjklJKkBK2nwZwk= -modernc.org/memory v1.4.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.19.3 h1:dIoagx6yIQT3V/zOSeAyZ8OqQyEr17YTgETOXTZNJMA= -modernc.org/sqlite v1.19.3/go.mod h1:xiyJD7FY8mTZXnQwE/gEL1STtFrrnDx03V8KhVQmcr8= +modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM= +modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk= modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= -modernc.org/tcl v1.15.0 h1:oY+JeD11qVVSgVvodMJsu7Edf8tr5E/7tuhF5cNYz34= +modernc.org/tcl v1.15.2 h1:C4ybAYCGJw968e+Me18oW55kD/FexcHbqH2xak1ROSY= modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= -rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +modernc.org/z v1.7.3 h1:zDJf6iHjrnB+WRD88stbXokugjyc0/pB91ri1gO6LZY= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/helm/cr.yaml b/helm/cr.yaml index 884c2b46bf..d39e8bdce2 100644 --- a/helm/cr.yaml +++ b/helm/cr.yaml @@ -1,2 +1,3 @@ release-name-template: "helm-{{ .Name }}-{{ .Version }}" -pages-index-path: docs/index.yaml \ No newline at end of file +pages-index-path: docs/index.yaml +make-release-latest: false \ No newline at end of file diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 6a428e00f6..668fd84ec2 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.12.2" -appVersion: "0.12.0" +version: "0.13.0" +appVersion: "0.13.0" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md index ca5705c036..562d1e2359 100644 --- a/helm/dendrite/README.md +++ b/helm/dendrite/README.md @@ -1,7 +1,7 @@ # dendrite -![Version: 0.12.2](https://img.shields.io/badge/Version-0.12.2-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.12.0](https://img.shields.io/badge/AppVersion-0.12.0-informational?style=flat-square) +![Version: 0.13.0](https://img.shields.io/badge/Version-0.13.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.13.0](https://img.shields.io/badge/AppVersion-0.13.0-informational?style=flat-square) Dendrite Matrix Homeserver Status: **NOT PRODUCTION READY** diff --git a/helm/dendrite/templates/configmap_grafana_dashboards.yaml b/helm/dendrite/templates/configmap_grafana_dashboards.yaml index e2abc4909c..9ab77e3b31 100644 --- a/helm/dendrite/templates/configmap_grafana_dashboards.yaml +++ b/helm/dendrite/templates/configmap_grafana_dashboards.yaml @@ -1,5 +1,5 @@ {{- if .Values.grafana.dashboards.enabled }} -{{- range $path, $bytes := .Files.Glob "grafana_dashboards/*" }} +{{- range $path, $bytes := .Files.Glob "grafana_dashboards/*.json" }} --- apiVersion: v1 kind: ConfigMap diff --git a/helm/dendrite/templates/ingress.yaml b/helm/dendrite/templates/ingress.yaml index 8f86ad7239..9ef413dc98 100644 --- a/helm/dendrite/templates/ingress.yaml +++ b/helm/dendrite/templates/ingress.yaml @@ -1,24 +1,25 @@ {{- if .Values.ingress.enabled -}} - {{- $fullName := include "dendrite.fullname" . -}} - {{- $svcPort := .Values.service.port -}} - {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} - {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} - {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} - {{- end }} - {{- end }} - {{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} +{{- $fullName := include "dendrite.fullname" . -}} +{{- $serverNameHost := .Values.dendrite_config.global.server_name -}} +{{- $wellKnownServerHost := default $serverNameHost (regexFind "^[^:]+" .Values.dendrite_config.global.well_known_server_name) -}} +{{- $wellKnownClientHost := default $serverNameHost (regexFind "^[^:]+" .Values.dendrite_config.global.well_known_client_name) -}} +{{- $allHosts := list $serverNameHost $wellKnownServerHost $wellKnownClientHost | uniq -}} +{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} apiVersion: networking.k8s.io/v1 - {{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} +{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} apiVersion: networking.k8s.io/v1beta1 - {{- else -}} +{{- else -}} apiVersion: extensions/v1beta1 - {{- end }} +{{- end }} kind: Ingress metadata: name: {{ $fullName }} labels: {{- include "dendrite.labels" . | nindent 4 }} annotations: + {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} + kubernetes.io/ingress.class: {{ .Values.ingress.className }} + {{- end }} {{- with .Values.ingress.annotations }} {{- toYaml . | nindent 4 }} {{- end }} @@ -26,7 +27,7 @@ spec: {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }} ingressClassName: {{ .Values.ingress.className }} {{- end }} - {{- if .Values.ingress.tls }} + {{- if kindIs "slice" .Values.ingress.tls }} tls: {{- range .Values.ingress.tls }} - hosts: @@ -35,8 +36,16 @@ spec: {{- end }} secretName: {{ .secretName }} {{- end }} + {{- else if .Values.ingress.tls.generate }} + tls: + - hosts: + {{- range $allHosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ $fullName }}-ingress-tls {{- end }} rules: + {{- if .Values.ingress.hostName }} - host: {{ .Values.ingress.hostName | quote }} http: paths: @@ -47,9 +56,60 @@ spec: service: name: {{ $fullName }} port: - number: {{ $svcPort }} + name: http + {{- else }} + serviceName: {{ $fullName }} + servicePort: http + {{- end }} + {{- else }} + - host: {{ $serverNameHost | quote }} + http: + paths: + - path: /.well-known/matrix + pathType: Prefix + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ $fullName }} + port: + name: http {{- else }} serviceName: {{ $fullName }} - servicePort: {{ $svcPort }} + servicePort: http {{- end }} - {{- end }} \ No newline at end of file + - host: {{ $wellKnownServerHost | quote }} + http: + paths: + {{- range list "/_matrix/key" "/_matrix/federation" }} + - path: {{ . | quote }} + pathType: Prefix + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ $fullName }} + port: + name: http + {{- else }} + serviceName: {{ $fullName }} + servicePort: http + {{- end }} + {{- end }} + - host: {{ $wellKnownClientHost | quote }} + http: + paths: + {{- range list "/_matrix/client" "/_matrix/media" }} + - path: {{ . | quote }} + pathType: Prefix + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ $fullName }} + port: + name: http + {{- else }} + serviceName: {{ $fullName }} + servicePort: http + {{- end }} + {{- end }} + {{- end }} +{{- end }} diff --git a/helm/dendrite/templates/prometheus-rules.yaml b/helm/dendrite/templates/prometheus-rules.yaml index 6693a4ed9d..dc6c12cf9d 100644 --- a/helm/dendrite/templates/prometheus-rules.yaml +++ b/helm/dendrite/templates/prometheus-rules.yaml @@ -6,7 +6,9 @@ metadata: name: {{ include "dendrite.fullname" . }} labels: {{- include "dendrite.labels" . | nindent 4 }} - {{- toYaml .Values.prometheus.rules.labels | nindent 4 }} + {{- with .Values.prometheus.rules.labels }} + {{- . | toYaml | nindent 4 }} + {{- end }} spec: groups: {{- if .Values.prometheus.rules.additionalRules }} diff --git a/helm/dendrite/templates/pvc.yaml b/helm/dendrite/templates/pvc.yaml index 897957e600..88eff3bede 100644 --- a/helm/dendrite/templates/pvc.yaml +++ b/helm/dendrite/templates/pvc.yaml @@ -12,7 +12,7 @@ spec: resources: requests: storage: {{ .Values.persistence.media.capacity }} - storageClassName: {{ .Values.persistence.storageClass }} + storageClassName: {{ default .Values.persistence.storageClass .Values.persistence.media.storageClass }} {{ end }} {{ if not .Values.persistence.jetstream.existingClaim }} --- @@ -28,7 +28,7 @@ spec: resources: requests: storage: {{ .Values.persistence.jetstream.capacity }} - storageClassName: {{ .Values.persistence.storageClass }} + storageClassName: {{ default .Values.persistence.storageClass .Values.persistence.jetstream.storageClass }} {{ end }} {{ if not .Values.persistence.search.existingClaim }} --- @@ -44,5 +44,5 @@ spec: resources: requests: storage: {{ .Values.persistence.search.capacity }} - storageClassName: {{ .Values.persistence.storageClass }} -{{ end }} \ No newline at end of file + storageClassName: {{ default .Values.persistence.storageClass .Values.persistence.search.storageClass }} +{{ end }} diff --git a/helm/dendrite/templates/servicemonitor.yaml b/helm/dendrite/templates/servicemonitor.yaml index 3819c7d020..4602140f87 100644 --- a/helm/dendrite/templates/servicemonitor.yaml +++ b/helm/dendrite/templates/servicemonitor.yaml @@ -9,7 +9,9 @@ metadata: name: {{ include "dendrite.fullname" . }} labels: {{- include "dendrite.labels" . | nindent 4 }} - {{- toYaml .Values.prometheus.servicemonitor.labels | nindent 4 }} + {{- with .Values.prometheus.servicemonitor.labels }} + {{- . | toYaml | nindent 4 }} + {{- end }} spec: endpoints: - port: http diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 41ec1c3906..2b009c7d64 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -19,29 +19,38 @@ signing_key: resources: requests: memory: "512Mi" - limits: memory: "4096Mi" persistence: - # -- The storage class to use for volume claims. Defaults to the - # cluster default storage class. + # -- The storage class to use for volume claims. + # Used unless specified at the specific component. + # Defaults to the cluster default storage class. storageClass: "" jetstream: # -- Use an existing volume claim for jetstream existingClaim: "" # -- PVC Storage Request for the jetstream volume capacity: "1Gi" + # -- The storage class to use for volume claims. + # Defaults to persistence.storageClass + storageClass: "" media: # -- Use an existing volume claim for media files existingClaim: "" # -- PVC Storage Request for the media volume capacity: "1Gi" + # -- The storage class to use for volume claims. + # Defaults to persistence.storageClass + storageClass: "" search: # -- Use an existing volume claim for the fulltext search index existingClaim: "" # -- PVC Storage Request for the search volume capacity: "1Gi" + # -- The storage class to use for volume claims. + # Defaults to persistence.storageClass + storageClass: "" # -- Add additional volumes to the Dendrite Pod extraVolumes: [] @@ -50,7 +59,6 @@ extraVolumes: [] # secret: # secretName: extra-config - # -- Configure additional mount points volumes in the Dendrite Pod extraVolumeMounts: [] # ex. @@ -212,7 +220,6 @@ dendrite_config: # - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) - app_service_api: # -- Disable the validation of TLS certificates of appservices. This is # not recommended in production since it may allow appservice traffic @@ -359,14 +366,18 @@ postgresql: enabled: false ingress: - # -- Create an ingress for a monolith deployment + # -- Create an ingress for the deployment enabled: false - hosts: [] + # -- The ingressClass to use. Will be converted to annotation if not yet supported. className: "" - hostName: "" # -- Extra, custom annotations annotations: {} - + # -- The ingress hostname for your matrix server. + # Should align with the server_name and well_known_* hosts. + # If not set, generated from the dendrite_config values. + hostName: "" + # -- TLS configuration. Should contain information for the server_name and well-known hosts. + # Alternatively, set tls.generate=true to generate defaults based on the dendrite_config. tls: [] service: diff --git a/internal/caching/cache_federationevents.go b/internal/caching/cache_federationevents.go index 24af51bdce..fc1f5496e8 100644 --- a/internal/caching/cache_federationevents.go +++ b/internal/caching/cache_federationevents.go @@ -1,14 +1,15 @@ package caching import ( + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) // FederationCache contains the subset of functions needed for // a federation event cache. type FederationCache interface { - GetFederationQueuedPDU(eventNID int64) (event *gomatrixserverlib.HeaderedEvent, ok bool) - StoreFederationQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) + GetFederationQueuedPDU(eventNID int64) (event *types.HeaderedEvent, ok bool) + StoreFederationQueuedPDU(eventNID int64, event *types.HeaderedEvent) EvictFederationQueuedPDU(eventNID int64) GetFederationQueuedEDU(eventNID int64) (event *gomatrixserverlib.EDU, ok bool) @@ -16,11 +17,11 @@ type FederationCache interface { EvictFederationQueuedEDU(eventNID int64) } -func (c Caches) GetFederationQueuedPDU(eventNID int64) (*gomatrixserverlib.HeaderedEvent, bool) { +func (c Caches) GetFederationQueuedPDU(eventNID int64) (*types.HeaderedEvent, bool) { return c.FederationPDUs.Get(eventNID) } -func (c Caches) StoreFederationQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) { +func (c Caches) StoreFederationQueuedPDU(eventNID int64, event *types.HeaderedEvent) { c.FederationPDUs.Set(eventNID, event) } diff --git a/internal/caching/cache_roomevents.go b/internal/caching/cache_roomevents.go index 14b6c3af86..e8bbe208e2 100644 --- a/internal/caching/cache_roomevents.go +++ b/internal/caching/cache_roomevents.go @@ -2,22 +2,21 @@ package caching import ( "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) // RoomServerEventsCache contains the subset of functions needed for // a roomserver event cache. type RoomServerEventsCache interface { - GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib.Event, bool) - StoreRoomServerEvent(eventNID types.EventNID, event *gomatrixserverlib.Event) + GetRoomServerEvent(eventNID types.EventNID) (*types.HeaderedEvent, bool) + StoreRoomServerEvent(eventNID types.EventNID, event *types.HeaderedEvent) InvalidateRoomServerEvent(eventNID types.EventNID) } -func (c Caches) GetRoomServerEvent(eventNID types.EventNID) (*gomatrixserverlib.Event, bool) { +func (c Caches) GetRoomServerEvent(eventNID types.EventNID) (*types.HeaderedEvent, bool) { return c.RoomServerEvents.Get(int64(eventNID)) } -func (c Caches) StoreRoomServerEvent(eventNID types.EventNID, event *gomatrixserverlib.Event) { +func (c Caches) StoreRoomServerEvent(eventNID types.EventNID, event *types.HeaderedEvent) { c.RoomServerEvents.Set(int64(eventNID), event) } diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index cffa101d5f..7400b868cf 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // ServerKeyCache contains the subset of functions needed for @@ -14,7 +15,7 @@ type ServerKeyCache interface { // The timestamp should be the timestamp of the event that is being // verified. We will not return keys from the cache that are not valid // at this timestamp. - GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp gomatrixserverlib.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) + GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest, timestamp spec.Timestamp) (response gomatrixserverlib.PublicKeyLookupResult, ok bool) // request -> result is emulating gomatrixserverlib.StoreKeys: // https://github.com/matrix-org/gomatrixserverlib/blob/f69539c86ea55d1e2cc76fd8e944e2d82d30397c/keyring.go#L112 @@ -23,11 +24,11 @@ type ServerKeyCache interface { func (c Caches) GetServerKey( request gomatrixserverlib.PublicKeyLookupRequest, - timestamp gomatrixserverlib.Timestamp, + timestamp spec.Timestamp, ) (gomatrixserverlib.PublicKeyLookupResult, bool) { key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) val, found := c.ServerKeys.Get(key) - if found && !val.WasValidAt(timestamp, true) { + if found && !val.WasValidAt(timestamp, gomatrixserverlib.StrictValiditySignatureCheck) { // The key wasn't valid at the requested timestamp so don't // return it. The caller will have to work out what to do. c.ServerKeys.Unset(key) diff --git a/internal/caching/caches.go b/internal/caching/caches.go index a678632ebe..6bae60d59f 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -28,12 +28,12 @@ type Caches struct { ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID - RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event + RoomServerEvents Cache[int64, *types.HeaderedEvent] // event NID -> event RoomServerStateKeys Cache[types.EventStateKeyNID, string] // eventStateKey NID -> event state key RoomServerStateKeyNIDs Cache[string, types.EventStateKeyNID] // event state key -> eventStateKey NID RoomServerEventTypeNIDs Cache[string, types.EventTypeNID] // eventType -> eventType NID RoomServerEventTypes Cache[types.EventTypeNID, string] // eventType NID -> eventType - FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU + FederationPDUs Cache[int64, *types.HeaderedEvent] // queue NID -> PDU FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU SpaceSummaryRooms Cache[string, fclient.MSC2946SpacesResponse] // room ID -> space response LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go index 4656b6b7eb..00989b7601 100644 --- a/internal/caching/impl_ristretto.go +++ b/internal/caching/impl_ristretto.go @@ -103,8 +103,8 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm Prefix: roomIDsCache, MaxAge: maxAge, }, - RoomServerEvents: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.Event]{ // event NID -> event - &RistrettoCachePartition[int64, *gomatrixserverlib.Event]{ + RoomServerEvents: &RistrettoCostedCachePartition[int64, *types.HeaderedEvent]{ // event NID -> event + &RistrettoCachePartition[int64, *types.HeaderedEvent]{ cache: cache, Prefix: roomEventsCache, MaxAge: maxAge, @@ -131,8 +131,8 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm Prefix: eventTypeNIDCache, MaxAge: maxAge, }, - FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU - &RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ + FederationPDUs: &RistrettoCostedCachePartition[int64, *types.HeaderedEvent]{ // queue NID -> PDU + &RistrettoCachePartition[int64, *types.HeaderedEvent]{ cache: cache, Prefix: federationPDUsCache, Mutable: true, diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 984a3f5397..56ee576a01 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -21,15 +21,27 @@ import ( "time" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" ) // ErrRoomNoExists is returned when trying to lookup the state of a room that // doesn't exist -var ErrRoomNoExists = errors.New("room does not exist") +var errRoomNoExists = fmt.Errorf("room does not exist") + +type ErrRoomNoExists struct{} + +func (e ErrRoomNoExists) Error() string { + return errRoomNoExists.Error() +} + +func (e ErrRoomNoExists) Unwrap() error { + return errRoomNoExists +} // QueryAndBuildEvent builds a Matrix event using the event builder and roomserver query // API client provided. If also fills roomserver query API response (if provided) @@ -39,54 +51,60 @@ var ErrRoomNoExists = errors.New("room does not exist") // Returns an error if something else went wrong func QueryAndBuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + proto *gomatrixserverlib.ProtoEvent, identity *fclient.SigningIdentity, evTime time.Time, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, -) (*gomatrixserverlib.HeaderedEvent, error) { +) (*types.HeaderedEvent, error) { if queryRes == nil { queryRes = &api.QueryLatestEventsAndStateResponse{} } - eventsNeeded, err := queryRequiredEventsForBuilder(ctx, builder, rsAPI, queryRes) + eventsNeeded, err := queryRequiredEventsForBuilder(ctx, proto, rsAPI, queryRes) if err != nil { // This can pass through a ErrRoomNoExists to the caller return nil, err } - return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes) + return BuildEvent(ctx, proto, identity, evTime, eventsNeeded, queryRes) } // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // provided. func BuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + proto *gomatrixserverlib.ProtoEvent, identity *fclient.SigningIdentity, evTime time.Time, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, -) (*gomatrixserverlib.HeaderedEvent, error) { - if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil { +) (*types.HeaderedEvent, error) { + if err := addPrevEventsToEvent(proto, eventsNeeded, queryRes); err != nil { return nil, err } + verImpl, err := gomatrixserverlib.GetRoomVersion(queryRes.RoomVersion) + if err != nil { + return nil, err + } + builder := verImpl.NewEventBuilderFromProtoEvent(proto) + event, err := builder.Build( evTime, identity.ServerName, identity.KeyID, - identity.PrivateKey, queryRes.RoomVersion, + identity.PrivateKey, ) if err != nil { return nil, err } - return event.Headered(queryRes.RoomVersion), nil + return &types.HeaderedEvent{PDU: event}, nil } // queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder. func queryRequiredEventsForBuilder( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, + proto *gomatrixserverlib.ProtoEvent, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.StateNeeded, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) if err != nil { - return nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.StateNeededForProtoEvent: %w", err) } if len(eventsNeeded.Tuples()) == 0 { @@ -95,7 +113,7 @@ func queryRequiredEventsForBuilder( // Ask the roomserver for information about this room queryReq := api.QueryLatestEventsAndStateRequest{ - RoomID: builder.RoomID, + RoomID: proto.RoomID, StateToFetch: eventsNeeded.Tuples(), } return &eventsNeeded, rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes) @@ -103,17 +121,12 @@ func queryRequiredEventsForBuilder( // addPrevEventsToEvent fills out the prev_events and auth_events fields in builder func addPrevEventsToEvent( - builder *gomatrixserverlib.EventBuilder, + builder *gomatrixserverlib.ProtoEvent, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, ) error { if !queryRes.RoomExists { - return ErrRoomNoExists - } - - eventFormat, err := queryRes.RoomVersion.EventFormat() - if err != nil { - return fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err) + return ErrRoomNoExists{} } builder.Depth = queryRes.Depth @@ -121,7 +134,7 @@ func addPrevEventsToEvent( authEvents := gomatrixserverlib.NewAuthEvents(nil) for i := range queryRes.StateEvents { - err = authEvents.AddEvent(queryRes.StateEvents[i].Event) + err := authEvents.AddEvent(queryRes.StateEvents[i].PDU) if err != nil { return fmt.Errorf("authEvents.AddEvent: %w", err) } @@ -132,22 +145,7 @@ func addPrevEventsToEvent( return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } - truncAuth, truncPrev := truncateAuthAndPrevEvents(refs, queryRes.LatestEvents) - switch eventFormat { - case gomatrixserverlib.EventFormatV1: - builder.AuthEvents = truncAuth - builder.PrevEvents = truncPrev - case gomatrixserverlib.EventFormatV2: - v2AuthRefs, v2PrevRefs := []string{}, []string{} - for _, ref := range truncAuth { - v2AuthRefs = append(v2AuthRefs, ref.EventID) - } - for _, ref := range truncPrev { - v2PrevRefs = append(v2PrevRefs, ref.EventID) - } - builder.AuthEvents = v2AuthRefs - builder.PrevEvents = v2PrevRefs - } + builder.AuthEvents, builder.PrevEvents = truncateAuthAndPrevEvents(refs, queryRes.LatestEvents) return nil } @@ -157,8 +155,8 @@ func addPrevEventsToEvent( // NOTSPEC: The limits here feel a bit arbitrary but they are currently // here because of https://github.com/matrix-org/matrix-doc/issues/2307 // and because Synapse will just drop events that don't comply. -func truncateAuthAndPrevEvents(auth, prev []gomatrixserverlib.EventReference) ( - truncAuth, truncPrev []gomatrixserverlib.EventReference, +func truncateAuthAndPrevEvents(auth, prev []string) ( + truncAuth, truncPrev []string, ) { truncAuth, truncPrev = auth, prev if len(truncAuth) > 10 { @@ -172,13 +170,22 @@ func truncateAuthAndPrevEvents(auth, prev []gomatrixserverlib.EventReference) ( // RedactEvent redacts the given event and sets the unsigned field appropriately. This should be used by // downstream components to the roomserver when an OutputTypeRedactedEvent occurs. -func RedactEvent(redactionEvent, redactedEvent *gomatrixserverlib.Event) error { +func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserverlib.PDU, querier api.QuerySenderIDAPI) error { // sanity check - if redactionEvent.Type() != gomatrixserverlib.MRoomRedaction { + if redactionEvent.Type() != spec.MRoomRedaction { return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type()) } redactedEvent.Redact() - if err := redactedEvent.SetUnsignedField("redacted_because", redactionEvent); err != nil { + validRoomID, err := spec.NewRoomID(redactionEvent.RoomID()) + if err != nil { + return err + } + senderID, err := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) + if err != nil { + return err + } + redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, *senderID, redactionEvent.StateKey()) + if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { return err } // NOTSPEC: sytest relies on this unspecced field existing :( diff --git a/internal/fulltext/bleve.go b/internal/fulltext/bleve.go index f7412470d8..d2807198af 100644 --- a/internal/fulltext/bleve.go +++ b/internal/fulltext/bleve.go @@ -23,6 +23,7 @@ import ( "github.com/blevesearch/bleve/v2" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/spec" // side effect imports to allow all possible languages _ "github.com/blevesearch/bleve/v2/analysis/lang/ar" @@ -47,7 +48,6 @@ import ( _ "github.com/blevesearch/bleve/v2/analysis/lang/sv" _ "github.com/blevesearch/bleve/v2/analysis/lang/tr" "github.com/blevesearch/bleve/v2/mapping" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/setup/config" ) @@ -79,9 +79,9 @@ func (i *IndexElement) SetContentType(v string) { switch v { case "m.room.message": i.ContentType = "content.body" - case gomatrixserverlib.MRoomName: + case spec.MRoomName: i.ContentType = "content.name" - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: i.ContentType = "content.topic" } } diff --git a/internal/fulltext/bleve_test.go b/internal/fulltext/bleve_test.go index a77c239372..decb5eccba 100644 --- a/internal/fulltext/bleve_test.go +++ b/internal/fulltext/bleve_test.go @@ -19,7 +19,7 @@ import ( "testing" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/fulltext" @@ -77,7 +77,7 @@ func mustAddTestData(t *testing.T, fts *fulltext.Search, firstStreamPos int64) ( Content: "Roomname testing", StreamPosition: streamPos, } - e.SetContentType(gomatrixserverlib.MRoomName) + e.SetContentType(spec.MRoomName) batchItems = append(batchItems, e) e = fulltext.IndexElement{ EventID: util.RandomString(16), @@ -85,7 +85,7 @@ func mustAddTestData(t *testing.T, fts *fulltext.Search, firstStreamPos int64) ( Content: "Room topic fulltext", StreamPosition: streamPos, } - e.SetContentType(gomatrixserverlib.MRoomTopic) + e.SetContentType(spec.MRoomTopic) batchItems = append(batchItems, e) if err := fts.Index(batchItems...); err != nil { t.Fatalf("failed to batch insert elements: %v", err) diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index d6c79e989b..802ff81871 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -21,17 +21,17 @@ import ( ) const ( - // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent + // KindNewEventPersisted is a hook which is called with *types.HeaderedEvent // It is run when a new event is persisted in the roomserver. // Usage: // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... }) KindNewEventPersisted = "new_event_persisted" - // KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent + // KindNewEventReceived is a hook which is called with *types.HeaderedEvent // It is run before a new event is processed by the roomserver. This hook can be used // to modify the event before it is persisted by adding data to `unsigned`. // Usage: // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { - // ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + // ev := headeredEvent.(*types.HeaderedEvent) // _ = ev.SetUnsignedField("key", "val") // }) KindNewEventReceived = "new_event_received" diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index f7e739a870..c8af1d26cd 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -31,9 +31,9 @@ import ( "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // BasicAuth is used for authorization on /metrics handlers @@ -101,7 +101,7 @@ func MakeAuthAPI( if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"), + JSON: spec.GuestAccessForbidden("Guest access not allowed"), } } @@ -177,7 +177,7 @@ func MakeAdminAPI( if device.AccountType != userapi.AccountTypeAdmin { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("This API can only be used by admin users."), + JSON: spec.Forbidden("This API can only be used by admin users."), } } return f(req, device) diff --git a/internal/httputil/rate_limiting.go b/internal/httputil/rate_limiting.go index dab36481e7..0b040d7f37 100644 --- a/internal/httputil/rate_limiting.go +++ b/internal/httputil/rate_limiting.go @@ -5,9 +5,9 @@ import ( "sync" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -118,7 +118,7 @@ func (l *RateLimits) Limit(req *http.Request, device *userapi.Device) *util.JSON // We hit the rate limit. Tell the client to back off. return &util.JSONResponse{ Code: http.StatusTooManyRequests, - JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), + JSON: spec.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), } } diff --git a/internal/httputil/routing.go b/internal/httputil/routing.go index c733c8ce7b..2052c798f5 100644 --- a/internal/httputil/routing.go +++ b/internal/httputil/routing.go @@ -15,10 +15,12 @@ package httputil import ( + "encoding/json" "net/http" "net/url" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib/spec" ) // URLDecodeMapValues is a function that iterates through each of the items in a @@ -66,13 +68,15 @@ func NewRouters() Routers { var NotAllowedHandler = WrapHandlerInCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusMethodNotAllowed) w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell + unrecognizedErr, _ := json.Marshal(spec.Unrecognized("Unrecognized request")) // nolint:misspell + _, _ = w.Write(unrecognizedErr) // nolint:misspell })) var NotFoundCORSHandler = WrapHandlerInCORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell + unrecognizedErr, _ := json.Marshal(spec.Unrecognized("Unrecognized request")) // nolint:misspell + _, _ = w.Write(unrecognizedErr) // nolint:misspell })) func (r *Routers) configureHTTPErrors() { diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go index 9969855148..202a10d79b 100644 --- a/internal/pushrules/default.go +++ b/internal/pushrules/default.go @@ -1,12 +1,10 @@ package pushrules -import ( - "github.com/matrix-org/gomatrixserverlib" -) +import "github.com/matrix-org/gomatrixserverlib/spec" // DefaultAccountRuleSets is the complete set of default push rules // for an account. -func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets { +func DefaultAccountRuleSets(localpart string, serverName spec.ServerName) *AccountRuleSets { return &AccountRuleSets{ Global: *DefaultGlobalRuleSet(localpart, serverName), } @@ -14,7 +12,7 @@ func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.Serve // DefaultGlobalRuleSet returns the default ruleset for a given (fully // qualified) MXID. -func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet { +func DefaultGlobalRuleSet(localpart string, serverName spec.ServerName) *RuleSet { return &RuleSet{ Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)), Content: defaultContentRules(localpart), diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index fc8e0f1745..28dea97c46 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // A RuleSetEvaluator encapsulates context to evaluate an event @@ -27,7 +28,7 @@ type EvaluationContext interface { // HasPowerLevel returns whether the user has at least the given // power in the room of the current event. - HasPowerLevel(userID, levelKey string) (bool, error) + HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) } // A kindAndRules is just here to simplify iteration of the (ordered) @@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat // MatchEvent returns the first matching rule. Returns nil if there // was no match rule. -func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) { +func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) { // TODO: server-default rules have lower priority than user rules, // but they are stored together with the user rules. It's a bit // unclear what the specification (11.14.1.4 Predefined rules) @@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, if rule.Default != defRules { continue } - ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, return nil, nil } -func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { +func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) { if !rule.Enabled { return false, nil } @@ -113,14 +114,23 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu return rule.RuleID == event.RoomID(), nil case SenderKind: - return rule.RuleID == event.Sender(), nil + userID := "" + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false, err + } + sender, err := userIDForSender(*validRoomID, event.SenderID()) + if err == nil { + userID = sender.String() + } + return rule.RuleID == userID, nil default: return false, nil } } -func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { +func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) { switch cond.Kind { case EventMatchCondition: if cond.Pattern == nil { @@ -143,14 +153,14 @@ func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec Evalua return cmp(n), nil case SenderNotificationPermissionCondition: - return ec.HasPowerLevel(event.Sender(), cond.Key) + return ec.HasPowerLevel(event.SenderID(), cond.Key) default: return false, nil } } -func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { +func patternMatches(key, pattern string, event gomatrixserverlib.PDU) (bool, error) { // It doesn't make sense for an empty pattern to match anything. if pattern == "" { return false, nil diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index ca8ae55192..a4ccc3d0fa 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -5,8 +5,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestRuleSetEvaluatorMatchEvent(t *testing.T) { ev := mustEventFromJSON(t, `{}`) defaultEnabled := &Rule{ @@ -29,7 +34,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) { Name string RuleSet RuleSet Want *Rule - Event *gomatrixserverlib.Event + Event gomatrixserverlib.PDU }{ {"empty", RuleSet{}, nil, ev}, {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled, ev}, @@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) { for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet) - got, err := rse.MatchEvent(tst.Event) + got, err := rse.MatchEvent(tst.Event, UserIDForSender) if err != nil { t.Fatalf("MatchEvent failed: %v", err) } @@ -68,7 +73,7 @@ func TestRuleMatches(t *testing.T) { {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, {"emptyContent", ContentKind, emptyRule, `{}`, false}, {"emptyRoom", RoomKind, emptyRule, `{}`, true}, - {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, {"disabled", OverrideKind, Rule{}, `{}`, false}, @@ -82,15 +87,15 @@ func TestRuleMatches(t *testing.T) { {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, - {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, - {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { - got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender) if err != nil { t.Fatalf("ruleMatches failed: %v", err) } @@ -153,8 +158,8 @@ type fakeEvaluationContext struct{ memberCount int } func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } func (f fakeEvaluationContext) RoomMemberCount() (int, error) { return f.memberCount, nil } -func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { - return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil +func (fakeEvaluationContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) { + return senderID == "@poweruser:example.com" && levelKey == "powerlevel", nil } func TestPatternMatches(t *testing.T) { @@ -188,8 +193,8 @@ func TestPatternMatches(t *testing.T) { } } -func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event { - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7) +func mustEventFromJSON(t *testing.T, json string) gomatrixserverlib.PDU { + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV7).NewEventFromTrustedJSON([]byte(json), false) if err != nil { t.Fatal(err) } diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index f50c51bd7b..b54ec3fb0c 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -10,6 +10,10 @@ import ( func ValidateRule(kind Kind, rule *Rule) []error { var errs []error + if len(rule.RuleID) > 0 && rule.RuleID[:1] == "." { + errs = append(errs, fmt.Errorf("invalid rule ID: rule can not start with a dot")) + } + if !validRuleIDRE.MatchString(rule.RuleID) { errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 60c02b1297..5bf7d819c0 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -21,14 +21,15 @@ import ( "sync" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" syncTypes "github.com/matrix-org/dendrite/syncapi/types" userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -58,7 +59,7 @@ type TxnReq struct { gomatrixserverlib.Transaction rsAPI api.FederationRoomserverAPI userAPI userAPI.FederationUserAPI - ourServerName gomatrixserverlib.ServerName + ourServerName spec.ServerName keys gomatrixserverlib.JSONVerifier roomsMu *MutexByRoom producer *producers.SyncAPIProducer @@ -68,16 +69,16 @@ type TxnReq struct { func NewTxnReq( rsAPI api.FederationRoomserverAPI, userAPI userAPI.FederationUserAPI, - ourServerName gomatrixserverlib.ServerName, + ourServerName spec.ServerName, keys gomatrixserverlib.JSONVerifier, roomsMu *MutexByRoom, producer *producers.SyncAPIProducer, inboundPresenceEnabled bool, pdus []json.RawMessage, edus []gomatrixserverlib.EDU, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, transactionID gomatrixserverlib.TransactionID, - destination gomatrixserverlib.ServerName, + destination spec.ServerName, ) TxnReq { t := TxnReq{ rsAPI: rsAPI, @@ -114,14 +115,13 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut if v, ok := roomVersions[roomID]; ok { return v } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) + roomVersion, err := t.rsAPI.QueryRoomVersionForRoom(ctx, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", roomID) return "" } - roomVersions[roomID] = verRes.RoomVersion - return verRes.RoomVersion + roomVersions[roomID] = roomVersion + return roomVersion } for _, pdu := range t.PDUs { @@ -136,7 +136,11 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut continue } roomVersion := getRoomVersion(header.RoomID) - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + continue + } + event, err := verImpl.NewEventFromUntrustedJSON(pdu) if err != nil { if _, ok := err.(gomatrixserverlib.BadJSONError); ok { // Room version 6 states that homeservers should strictly enforce canonical JSON @@ -148,13 +152,13 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut // See https://github.com/matrix-org/synapse/issues/7543 return nil, &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), + JSON: spec.BadJSON("PDU contains bad JSON"), } } util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) continue } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") { continue } if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { @@ -163,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), @@ -178,8 +184,8 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut ctx, t.rsAPI, api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(roomVersion), + []*rstypes.HeaderedEvent{ + {PDU: event}, }, t.Destination, t.Origin, @@ -207,7 +213,7 @@ func (t *TxnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { EDUCountTotal.Inc() switch e.Type { - case gomatrixserverlib.MTyping: + case spec.MTyping: // https://matrix.org/docs/spec/server_server/latest#typing-notifications var typingPayload struct { RoomID string `json:"room_id"` @@ -228,7 +234,7 @@ func (t *TxnReq) processEDUs(ctx context.Context) { if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") } - case gomatrixserverlib.MDirectToDevice: + case spec.MDirectToDevice: // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema var directPayload gomatrixserverlib.ToDeviceMessage if err := json.Unmarshal(e.Content, &directPayload); err != nil { @@ -255,12 +261,12 @@ func (t *TxnReq) processEDUs(ctx context.Context) { } } } - case gomatrixserverlib.MDeviceListUpdate: + case spec.MDeviceListUpdate: if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { sentry.CaptureException(err) util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") } - case gomatrixserverlib.MReceipt: + case spec.MReceipt: // https://matrix.org/docs/spec/server_server/r0.1.4#receipts payload := map[string]types.FederationReceiptMRead{} @@ -296,7 +302,7 @@ func (t *TxnReq) processEDUs(ctx context.Context) { sentry.CaptureException(err) logrus.WithError(err).Errorf("Failed to process signing key update") } - case gomatrixserverlib.MPresence: + case spec.MPresence: if t.inboundPresenceEnabled { if err := t.processPresence(ctx, e); err != nil { logrus.WithError(err).Errorf("Failed to process presence update") @@ -336,7 +342,7 @@ func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) e // processReceiptEvent sends receipt events to JetStream func (t *TxnReq) processReceiptEvent(ctx context.Context, userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, + timestamp spec.Timestamp, eventIDs []string, ) error { if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index c152eb2856..ffc1cd89ab 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "go.uber.org/atomic" @@ -30,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/producers" rsAPI "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -39,8 +41,8 @@ import ( ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = spec.ServerName("kaer.morhen") + testDestination = spec.ServerName("white.orchard") ) var ( @@ -58,27 +60,28 @@ var ( } testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) testRoomVersion = gomatrixserverlib.RoomVersionV1 - testEvents = []*gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) + testEvents = []*rstypes.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*rstypes.HeaderedEvent) ) type FakeRsAPI struct { rsAPI.RoomserverInternalAPI - shouldFailQuery bool - bannedFromRoom bool - shouldEventsFail bool + shouldFailQuery bool + bannedFromRoom bool +} + +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func (r *FakeRsAPI) QueryRoomVersionForRoom( ctx context.Context, - req *rsAPI.QueryRoomVersionForRoomRequest, - res *rsAPI.QueryRoomVersionForRoomResponse, -) error { + roomID string, +) (gomatrixserverlib.RoomVersion, error) { if r.shouldFailQuery { - return fmt.Errorf("Failure") + return "", fmt.Errorf("Failure") } - res.RoomVersion = gomatrixserverlib.RoomVersionV10 - return nil + return gomatrixserverlib.RoomVersionV10, nil } func (r *FakeRsAPI) QueryServerBannedFromRoom( @@ -98,11 +101,7 @@ func (r *FakeRsAPI) InputRoomEvents( ctx context.Context, req *rsAPI.InputRoomEventsRequest, res *rsAPI.InputRoomEventsResponse, -) error { - if r.shouldEventsFail { - return fmt.Errorf("Failure") - } - return nil +) { } func TestEmptyTransactionRequest(t *testing.T) { @@ -184,18 +183,6 @@ func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { } } -func TestProcessTransactionRequestPDUSendFail(t *testing.T) { - keyRing := &test.NopJSONVerifier{} - txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") - txnRes, jsonRes := txn.ProcessTransaction(context.Background()) - - assert.Nil(t, jsonRes) - assert.Equal(t, 1, len(txnRes.PDUs)) - for _, result := range txnRes.PDUs { - assert.NotEmpty(t, result.Error) - } -} - func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ @@ -235,7 +222,7 @@ func TestProcessTransactionRequestEDUTyping(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.typing"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -301,7 +288,7 @@ func TestProcessTransactionRequestEDUToDevice(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -378,7 +365,7 @@ func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -441,7 +428,7 @@ func TestProcessTransactionRequestEDUReceipt(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") badUser := gomatrixserverlib.EDU{Type: "m.receipt"} if badUser.Content, err = json.Marshal(map[string]interface{}{ roomID: map[string]interface{}{ @@ -519,7 +506,7 @@ func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -576,7 +563,7 @@ func TestProcessTransactionRequestEDUPresence(t *testing.T) { t.Errorf("failed to marshal EDU JSON") } badEDU := gomatrixserverlib.EDU{Type: "m.presence"} - badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badEDU.Content = spec.RawJSON("badjson") edus := []gomatrixserverlib.EDU{badEDU, edu} ctx := process.NewProcessContext() @@ -632,11 +619,11 @@ func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { func init() { for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + e, err := gomatrixserverlib.MustGetRoomVersion(testRoomVersion).NewEventFromTrustedJSON(j, false) if err != nil { panic("cannot load test data: " + err.Error()) } - h := e.Headered(testRoomVersion) + h := &rstypes.HeaderedEvent{PDU: e} testEvents = append(testEvents, h) if e.StateKey() != nil { testStateEvents[gomatrixserverlib.StateKeyTuple{ @@ -655,16 +642,19 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *rsAPI.InputRoomEventsRequest, response *rsAPI.InputRoomEventsResponse, -) error { +) { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) for _, ire := range request.InputRoomEvents { fmt.Println("InputRoomEvents: ", ire.Event.EventID()) } - return nil } // Query the latest events and state for a room from the room server. @@ -721,11 +711,9 @@ func (t *testRoomserverAPI) QueryServerJoinedToRoom( // Asks for the room version for a given room. func (t *testRoomserverAPI) QueryRoomVersionForRoom( ctx context.Context, - request *rsAPI.QueryRoomVersionForRoomRequest, - response *rsAPI.QueryRoomVersionForRoomResponse, -) error { - response.RoomVersion = testRoomVersion - return nil + roomID string, +) (gomatrixserverlib.RoomVersion, error) { + return testRoomVersion, nil } func (t *testRoomserverAPI) QueryServerBannedFromRoom( @@ -780,7 +768,7 @@ NextPDU: } } -func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { +func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*rstypes.HeaderedEvent) { for _, g := range got { fmt.Println("GOT ", g.Event.EventID()) } @@ -804,7 +792,7 @@ func TestBasicTransaction(t *testing.T) { } txn := mustCreateTransaction(rsAPI, pdus) mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*rstypes.HeaderedEvent{testEvents[len(testEvents)-1]}) } // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver @@ -817,5 +805,5 @@ func TestTransactionFailAuthChecks(t *testing.T) { txn := mustCreateTransaction(rsAPI, pdus) mustProcessTransaction(t, txn, []string{}) // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*rstypes.HeaderedEvent{testEvents[len(testEvents)-1]}) } diff --git a/internal/validate.go b/internal/validate.go index 0461b897ed..99088f2403 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -20,8 +20,7 @@ import ( "net/http" "regexp" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -58,19 +57,19 @@ func PasswordResponse(err error) *util.JSONResponse { case ErrPasswordWeak: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + JSON: spec.WeakPassword(ErrPasswordWeak.Error()), } case ErrPasswordTooLong: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), + JSON: spec.BadJSON(ErrPasswordTooLong.Error()), } } return nil } // ValidateUsername returns an error if the username is invalid -func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { +func ValidateUsername(localpart string, domain spec.ServerName) error { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { return ErrUsernameTooLong @@ -88,19 +87,19 @@ func UsernameResponse(err error) *util.JSONResponse { case ErrUsernameTooLong: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), + JSON: spec.BadJSON(err.Error()), } case ErrUsernameInvalid, ErrUsernameUnderscore: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), + JSON: spec.InvalidUsername(err.Error()), } } return nil } // ValidateApplicationServiceUsername returns an error if the username is invalid for an application service -func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { +func ValidateApplicationServiceUsername(localpart string, domain spec.ServerName) error { if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { return ErrUsernameTooLong } else if !validUsernameRegex.MatchString(localpart) { diff --git a/internal/validate_test.go b/internal/validate_test.go index d0ad047079..e3a10178fa 100644 --- a/internal/validate_test.go +++ b/internal/validate_test.go @@ -6,8 +6,7 @@ import ( "strings" "testing" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -22,13 +21,13 @@ func Test_validatePassword(t *testing.T) { name: "password too short", password: "shortpw", wantError: ErrPasswordWeak, - wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.WeakPassword(ErrPasswordWeak.Error())}, }, { name: "password too long", password: strings.Repeat("a", maxPasswordLength+1), wantError: ErrPasswordTooLong, - wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.BadJSON(ErrPasswordTooLong.Error())}, }, { name: "password OK", @@ -54,7 +53,7 @@ func Test_validateUsername(t *testing.T) { tests := []struct { name string localpart string - domain gomatrixserverlib.ServerName + domain spec.ServerName wantErr error wantJSON *util.JSONResponse }{ @@ -65,7 +64,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { @@ -75,7 +74,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { @@ -85,7 +84,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameTooLong, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + JSON: spec.BadJSON(ErrUsernameTooLong.Error()), }, }, { @@ -95,7 +94,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameUnderscore, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + JSON: spec.InvalidUsername(ErrUsernameUnderscore.Error()), }, }, { @@ -115,7 +114,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { @@ -135,7 +134,7 @@ func Test_validateUsername(t *testing.T) { wantErr: ErrUsernameInvalid, wantJSON: &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + JSON: spec.InvalidUsername(ErrUsernameInvalid.Error()), }, }, { diff --git a/internal/version.go b/internal/version.go index 9075475891..c42b20390a 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 12 - VersionPatch = 0 + VersionMinor = 13 + VersionPatch = 1 VersionTag = "" // example: "rc1" ) diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 1ab2dfb28c..8fb1b6534e 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -30,14 +30,13 @@ import ( "sync" "unicode" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -72,7 +71,7 @@ type downloadRequest struct { func Download( w http.ResponseWriter, req *http.Request, - origin gomatrixserverlib.ServerName, + origin spec.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db storage.Database, @@ -130,7 +129,7 @@ func Download( // TODO: Handle the fact we might have started writing the response dReq.jsonErrorResponse(w, util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Failed to download: " + err.Error()), + JSON: spec.NotFound("Failed to download: " + err.Error()), }) return } @@ -138,7 +137,7 @@ func Download( if metadata == nil { dReq.jsonErrorResponse(w, util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("File not found"), + JSON: spec.NotFound("File not found"), }) return } @@ -168,7 +167,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if !mediaIDRegex.MatchString(string(r.MediaMetadata.MediaID)) { return &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("mediaId must be a non-empty string using only characters in %v", mediaIDCharacters)), + JSON: spec.NotFound(fmt.Sprintf("mediaId must be a non-empty string using only characters in %v", mediaIDCharacters)), } } // Note: the origin will be validated either by comparison to the configured server name of this homeserver @@ -176,7 +175,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if r.MediaMetadata.Origin == "" { return &util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("serverName must be a non-empty string"), + JSON: spec.NotFound("serverName must be a non-empty string"), } } @@ -184,7 +183,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if r.ThumbnailSize.Width <= 0 || r.ThumbnailSize.Height <= 0 { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("width and height must be greater than 0"), + JSON: spec.Unknown("width and height must be greater than 0"), } } // Default method to scale if not set @@ -194,7 +193,7 @@ func (r *downloadRequest) Validate() *util.JSONResponse { if r.ThumbnailSize.ResizeMethod != types.Crop && r.ThumbnailSize.ResizeMethod != types.Scale { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("method must be one of crop or scale"), + JSON: spec.Unknown("method must be one of crop or scale"), } } } diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 79e8308ae8..e0af4a911d 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -25,8 +25,8 @@ import ( "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -140,7 +140,7 @@ func makeDownloadAPI( } vars, _ := httputil.URLDecodeMapValues(mux.Vars(req)) - serverName := gomatrixserverlib.ServerName(vars["serverName"]) + serverName := spec.ServerName(vars["serverName"]) // For the purposes of loop avoidance, we will return a 404 if allow_remote is set to // false in the query string and the target server name isn't our own. diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 2175648eaf..5ac1d076b8 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -26,7 +26,6 @@ import ( "path" "strings" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -34,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -165,7 +165,7 @@ func (r *uploadRequest) doUpload( }).Warn("Error while transferring file") return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to upload"), + JSON: spec.Unknown("Failed to upload"), } } @@ -184,8 +184,10 @@ func (r *uploadRequest) doUpload( if err != nil { fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Error querying the database by hash.") - resErr := jsonerror.InternalServerError() - return &resErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if existingMetadata != nil { // The file already exists, delete the uploaded temporary file. @@ -194,8 +196,10 @@ func (r *uploadRequest) doUpload( mediaID, merr := r.generateMediaID(ctx, db) if merr != nil { r.Logger.WithError(merr).Error("Failed to generate media ID for existing file") - resErr := jsonerror.InternalServerError() - return &resErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // Then amend the upload metadata. @@ -217,8 +221,10 @@ func (r *uploadRequest) doUpload( if err != nil { fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Failed to generate media ID for new upload") - resErr := jsonerror.InternalServerError() - return &resErr + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -239,7 +245,7 @@ func (r *uploadRequest) doUpload( func requestEntityTooLargeJSONResponse(maxFileSizeBytes config.FileSizeBytes) *util.JSONResponse { return &util.JSONResponse{ Code: http.StatusRequestEntityTooLarge, - JSON: jsonerror.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSizeBytes)), + JSON: spec.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSizeBytes)), } } @@ -251,7 +257,7 @@ func (r *uploadRequest) Validate(maxFileSizeBytes config.FileSizeBytes) *util.JS if strings.HasPrefix(string(r.MediaMetadata.UploadName), "~") { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("File name must not begin with '~'."), + JSON: spec.Unknown("File name must not begin with '~'."), } } // TODO: Validate filename - what are the valid characters? @@ -264,7 +270,7 @@ func (r *uploadRequest) Validate(maxFileSizeBytes config.FileSizeBytes) *util.JS if _, _, err := gomatrixserverlib.SplitID('@', string(r.MediaMetadata.UserID)); err != nil { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("user id must be in the form @localpart:domain"), + JSON: spec.BadJSON("user id must be in the form @localpart:domain"), } } } @@ -290,7 +296,7 @@ func (r *uploadRequest) storeFileAndMetadata( r.Logger.WithError(err).Error("Failed to move file.") return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to upload"), + JSON: spec.Unknown("Failed to upload"), } } if duplicate { @@ -307,7 +313,7 @@ func (r *uploadRequest) storeFileAndMetadata( } return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown("Failed to upload"), + JSON: spec.Unknown("Failed to upload"), } } diff --git a/mediaapi/storage/interface.go b/mediaapi/storage/interface.go index d083be1eb4..cf3e7df571 100644 --- a/mediaapi/storage/interface.go +++ b/mediaapi/storage/interface.go @@ -18,7 +18,7 @@ import ( "context" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database interface { @@ -28,12 +28,12 @@ type Database interface { type MediaRepository interface { StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error - GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) - GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) + GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) } type Thumbnails interface { StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error - GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) - GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) + GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) + GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) ([]*types.ThumbnailMetadata, error) } diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go index 41cee48781..0583dd0175 100644 --- a/mediaapi/storage/postgres/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const mediaSchema = ` @@ -88,7 +88,7 @@ func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) func (s *mediaStatements) InsertMedia( ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + mediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, @@ -104,7 +104,7 @@ func (s *mediaStatements) InsertMedia( } func (s *mediaStatements) SelectMedia( - ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, @@ -124,7 +124,7 @@ func (s *mediaStatements) SelectMedia( } func (s *mediaStatements) SelectMediaByHash( - ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, diff --git a/mediaapi/storage/postgres/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go index 7e07b476e4..8544855288 100644 --- a/mediaapi/storage/postgres/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const thumbnailSchema = ` @@ -91,7 +91,7 @@ func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { func (s *thumbnailStatements) InsertThumbnail( ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata, ) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + thumbnailMetadata.MediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, @@ -110,7 +110,7 @@ func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, width, height int, resizeMethod string, ) (*types.ThumbnailMetadata, error) { @@ -141,7 +141,7 @@ func (s *thumbnailStatements) SelectThumbnail( } func (s *thumbnailStatements) SelectThumbnails( - ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName, ) ([]*types.ThumbnailMetadata, error) { rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, diff --git a/mediaapi/storage/shared/mediaapi.go b/mediaapi/storage/shared/mediaapi.go index c8d9ad6ab6..867405fb37 100644 --- a/mediaapi/storage/shared/mediaapi.go +++ b/mediaapi/storage/shared/mediaapi.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Database struct { @@ -42,7 +42,7 @@ func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.M // GetMediaMetadata returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { +func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) { mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil @@ -53,7 +53,7 @@ func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, m // GetMediaMetadataByHash returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this media. -func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { +func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) { mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin) if err != nil && err == sql.ErrNoRows { return nil, nil @@ -72,7 +72,7 @@ func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.T // GetThumbnail returns metadata about a specific thumbnail. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there is no metadata associated with this thumbnail. -func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) { +func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) { metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod) if err != nil { if err == sql.ErrNoRows { @@ -86,7 +86,7 @@ func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, media // GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. // Returns nil metadata if there are no thumbnails associated with this media. -func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) { +func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin spec.ServerName) ([]*types.ThumbnailMetadata, error) { metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin) if err != nil { if err == sql.ErrNoRows { diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index 78431967f2..625688cd77 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const mediaSchema = ` @@ -91,7 +91,7 @@ func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { func (s *mediaStatements) InsertMedia( ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + mediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, @@ -107,7 +107,7 @@ func (s *mediaStatements) InsertMedia( } func (s *mediaStatements) SelectMedia( - ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, @@ -127,7 +127,7 @@ func (s *mediaStatements) SelectMedia( } func (s *mediaStatements) SelectMediaByHash( - ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 5ff2fece0e..259d55b73c 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const thumbnailSchema = ` @@ -79,7 +79,7 @@ func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) { } func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error { - thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + thumbnailMetadata.MediaMetadata.CreationTimestamp = spec.AsTimestamp(time.Now()) _, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext( ctx, thumbnailMetadata.MediaMetadata.MediaID, @@ -98,7 +98,7 @@ func (s *thumbnailStatements) SelectThumbnail( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, width, height int, resizeMethod string, ) (*types.ThumbnailMetadata, error) { @@ -130,7 +130,7 @@ func (s *thumbnailStatements) SelectThumbnail( func (s *thumbnailStatements) SelectThumbnails( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, ) ([]*types.ThumbnailMetadata, error) { rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext( ctx, mediaID, mediaOrigin, diff --git a/mediaapi/storage/tables/interface.go b/mediaapi/storage/tables/interface.go index bf63bc6abe..2ff8039b4a 100644 --- a/mediaapi/storage/tables/interface.go +++ b/mediaapi/storage/tables/interface.go @@ -19,28 +19,28 @@ import ( "database/sql" "github.com/matrix-org/dendrite/mediaapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Thumbnails interface { InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error SelectThumbnail( ctx context.Context, txn *sql.Tx, - mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, + mediaID types.MediaID, mediaOrigin spec.ServerName, width, height int, resizeMethod string, ) (*types.ThumbnailMetadata, error) SelectThumbnails( ctx context.Context, txn *sql.Tx, mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, + mediaOrigin spec.ServerName, ) ([]*types.ThumbnailMetadata, error) } type MediaRepository interface { InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error - SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin spec.ServerName) (*types.MediaMetadata, error) SelectMediaByHash( ctx context.Context, txn *sql.Tx, - mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, + mediaHash types.Base64Hash, mediaOrigin spec.ServerName, ) (*types.MediaMetadata, error) } diff --git a/mediaapi/types/types.go b/mediaapi/types/types.go index ab28b34105..e1c29e0f66 100644 --- a/mediaapi/types/types.go +++ b/mediaapi/types/types.go @@ -18,7 +18,7 @@ import ( "sync" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // FileSizeBytes is a file size in bytes @@ -48,10 +48,10 @@ type MatrixUserID string // MediaMetadata is metadata associated with a media file type MediaMetadata struct { MediaID MediaID - Origin gomatrixserverlib.ServerName + Origin spec.ServerName ContentType ContentType FileSizeBytes FileSizeBytes - CreationTimestamp gomatrixserverlib.Timestamp + CreationTimestamp spec.Timestamp UploadName Filename Base64Hash Base64Hash UserID MatrixUserID diff --git a/relayapi/storage/storage_wasm.go b/relayapi/storage/storage_wasm.go index 5ab872f744..206296a2fc 100644 --- a/relayapi/storage/storage_wasm.go +++ b/relayapi/storage/storage_wasm.go @@ -28,7 +28,7 @@ func NewDatabase( conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.FederationCache, - isLocalServerName func(gomatrixserverlib.ServerName) bool, + isLocalServerName func(spec.ServerName) bool, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index b18daa3deb..b04828b692 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -23,7 +23,9 @@ import ( "strings" "sync" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -33,7 +35,7 @@ type ServerACLDatabase interface { // GetStateEvent returns the state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) } type ServerACLs struct { @@ -61,7 +63,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { continue } if state != nil { - acls.OnServerACLUpdate(state.Event) + acls.OnServerACLUpdate(state.PDU) } } return acls @@ -86,7 +88,7 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) { return regexp.Compile(escaped) } -func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { +func (s *ServerACLs) OnServerACLUpdate(state gomatrixserverlib.PDU) { acls := &serverACL{} if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil { logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs") @@ -120,7 +122,7 @@ func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { s.acls[state.RoomID()] = acls } -func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool { +func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool { s.aclsMutex.RLock() // First of all check if we have an ACL for this room. If we don't then // no servers are banned from the room. @@ -133,7 +135,7 @@ func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerN // Split the host and port apart. This is because the spec calls on us to // validate the hostname only in cases where the port is also present. if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil { - serverName = gomatrixserverlib.ServerName(serverNameOnly) + serverName = spec.ServerName(serverNameOnly) } // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding // a /0 prefix length just to trick ParseCIDR into working. If we find that diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 37892a44a1..c091cf6a32 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -14,7 +14,11 @@ package api -import "regexp" +import ( + "regexp" + + "github.com/matrix-org/gomatrixserverlib/spec" +) // SetRoomAliasRequest is a request to SetRoomAlias type SetRoomAliasRequest struct { @@ -62,7 +66,7 @@ type GetAliasesForRoomIDResponse struct { // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias - UserID string `json:"user_id"` + SenderID spec.SenderID `json:"user_id"` // The room alias to remove Alias string `json:"alias"` } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index dda5bb5a4b..ab56529c5d 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -2,14 +2,48 @@ package api import ( "context" + "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/roomserver/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) +// ErrInvalidID is an error returned if the userID is invalid +type ErrInvalidID struct { + Err error +} + +func (e ErrInvalidID) Error() string { + return e.Err.Error() +} + +// ErrNotAllowed is an error returned if the user is not allowed +// to execute some action (e.g. invite) +type ErrNotAllowed struct { + Err error +} + +func (e ErrNotAllowed) Error() string { + return e.Err.Error() +} + +type RestrictedJoinAPI interface { + CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) + InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) + RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) + QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) + QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error) + LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) +} + // RoomserverInputAPI is used to write events to the room server. type RoomserverInternalAPI interface { SyncRoomserverAPI @@ -17,6 +51,8 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + QuerySenderIDAPI + UserRoomPrivateKeyCreator // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -35,12 +71,23 @@ type RoomserverInternalAPI interface { ) error } +type UserRoomPrivateKeyCreator interface { + // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. + GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) + StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error +} + type InputRoomEventsAPI interface { InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, - ) error + ) +} + +type QuerySenderIDAPI interface { + QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) + QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. @@ -70,6 +117,7 @@ type QueryEventsAPI interface { type SyncRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine @@ -101,7 +149,7 @@ type SyncRoomserverAPI interface { ) error // QueryMembershipAtEvent queries the memberships at the given events. - // Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent. + // Returns a map from eventID to a slice of types.HeaderedEvent. QueryMembershipAtEvent( ctx context.Context, request *QueryMembershipAtEventRequest, @@ -110,6 +158,7 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { + QuerySenderIDAPI // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // which room to use by querying the first events roomID. QueryEventsByID( @@ -136,53 +185,65 @@ type ClientRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QueryEventsAPI + QuerySenderIDAPI + UserRoomPrivateKeyCreator QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error // QueryKnownUsers returns a list of users that we know about from our joined rooms. QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error - QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error + QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error + PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version - PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error - PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error - PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error - PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error - PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error - PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error - PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error - PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error - PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error + PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) + PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) + PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) + PerformAdminPurgeRoom(ctx context.Context, roomID string) error + PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error + PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error) + PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error + PerformInvite(ctx context.Context, req *PerformInviteRequest) error + PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error - PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error + PerformPublish(ctx context.Context, req *PerformPublishRequest) error // PerformForget forgets a rooms history for a specific user PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error RemoveRoomAlias(ctx context.Context, req *RemoveRoomAliasRequest, res *RemoveRoomAliasResponse) error + SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) } type UserRoomserverAPI interface { + QuerySenderIDAPI QueryLatestEventsAndStateAPI KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error - PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error - PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error + PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) + PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) } type FederationRoomserverAPI interface { + RestrictedJoinAPI InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI + UserRoomPrivateKeyCreator + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) + SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error + QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error - QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error + QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // which room to use by querying the first events roomID. @@ -191,19 +252,22 @@ type FederationRoomserverAPI interface { // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate // the state and auth chain to return. QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error - // Query if we think we're still in a room. - QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error) + QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error + QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error - PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error + HandleInvite(ctx context.Context, event *types.HeaderedEvent) error + + PerformInvite(ctx context.Context, req *PerformInviteRequest) error // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error + + IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) + StateQuerier() gomatrixserverlib.StateQuerier } type KeyserverRoomserverAPI interface { diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 88d5232704..8947ad624c 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -18,7 +18,9 @@ package api import ( "fmt" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Kind int @@ -66,9 +68,9 @@ type InputRoomEvent struct { // This controls how the event is processed. Kind Kind `json:"kind"` // The event JSON for the event to add. - Event *gomatrixserverlib.HeaderedEvent `json:"event"` + Event *types.HeaderedEvent `json:"event"` // Which server told us about this event. - Origin gomatrixserverlib.ServerName `json:"origin"` + Origin spec.ServerName `json:"origin"` // Whether the state is supplied as a list of event IDs or whether it // should be derived from the state at the previous events. HasState bool `json:"has_state"` @@ -94,9 +96,9 @@ type TransactionID struct { // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - Asynchronous bool `json:"async"` - VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + Asynchronous bool `json:"async"` + VirtualHost spec.ServerName `json:"virtual_host"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 0c0f52c457..852b64206d 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -15,7 +15,9 @@ package api import ( + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // An OutputType is a type of roomserver output. @@ -106,7 +108,7 @@ const ( // prev_events. type OutputNewRoomEvent struct { // The Event. - Event *gomatrixserverlib.HeaderedEvent `json:"event"` + Event *types.HeaderedEvent `json:"event"` // Does the event completely rewrite the room state? If so, then AddsStateEventIDs // will contain the entire room state. RewritesState bool `json:"rewrites_state,omitempty"` @@ -169,8 +171,8 @@ type OutputNewRoomEvent struct { HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` } -func (o *OutputNewRoomEvent) NeededStateEventIDs() ([]*gomatrixserverlib.HeaderedEvent, []string) { - addsStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, 1) +func (o *OutputNewRoomEvent) NeededStateEventIDs() ([]*types.HeaderedEvent, []string) { + addsStateEvents := make([]*types.HeaderedEvent, 0, 1) missingEventIDs := make([]string, 0, len(o.AddsStateEventIDs)) for _, eventID := range o.AddsStateEventIDs { if eventID != o.Event.EventID() { @@ -193,7 +195,7 @@ func (o *OutputNewRoomEvent) NeededStateEventIDs() ([]*gomatrixserverlib.Headere // should build their current room state up from OutputNewRoomEvents only. type OutputOldRoomEvent struct { // The Event. - Event *gomatrixserverlib.HeaderedEvent `json:"event"` + Event *types.HeaderedEvent `json:"event"` HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` } @@ -204,7 +206,7 @@ type OutputNewInviteEvent struct { // The room version of the invited room. RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` // The "m.room.member" invite event. - Event *gomatrixserverlib.HeaderedEvent `json:"event"` + Event *types.HeaderedEvent `json:"event"` } // An OutputRetireInviteEvent is written whenever an existing invite is no longer @@ -213,8 +215,10 @@ type OutputNewInviteEvent struct { type OutputRetireInviteEvent struct { // The ID of the "m.room.member" invite event. EventID string - // The target user ID of the "m.room.member" invite event that was retired. - TargetUserID string + // The room ID of the "m.room.member" invite event. + RoomID string + // The target sender ID of the "m.room.member" invite event that was retired. + TargetSenderID spec.SenderID // Optional event ID of the event that replaced the invite. // This can be empty if the invite was rejected locally and we were unable // to reach the server that originally sent the invite. @@ -231,7 +235,7 @@ type OutputRedactedEvent struct { // The event ID that was redacted RedactedEventID string // The value of `unsigned.redacted_because` - the redaction event itself - RedactedBecause *gomatrixserverlib.HeaderedEvent + RedactedBecause *types.HeaderedEvent } // An OutputNewPeek is written whenever a user starts peeking into a room @@ -250,7 +254,7 @@ type OutputNewInboundPeek struct { // a race between tracking the state returned by /peek and emitting subsequent // peeked events) LatestEventID string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName // how often we told the peeking server to renew the peek RenewalInterval int64 } diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 83cb0460ae..b466b7ba8b 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -1,100 +1,48 @@ package api import ( + "crypto/ed25519" "encoding/json" - "fmt" - "net/http" + "time" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - - "github.com/matrix-org/dendrite/clientapi/jsonerror" ) -type PerformErrorCode int - -type PerformError struct { - Msg string - RemoteCode int // remote HTTP status code, for PerformErrRemote - Code PerformErrorCode -} - -func (p *PerformError) Error() string { - return fmt.Sprintf("%d : %s", p.Code, p.Msg) -} - -// JSONResponse maps error codes to suitable HTTP error codes, defaulting to 500. -func (p *PerformError) JSONResponse() util.JSONResponse { - switch p.Code { - case PerformErrorBadRequest: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(p.Msg), - } - case PerformErrorNoRoom: - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound(p.Msg), - } - case PerformErrorNotAllowed: - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(p.Msg), - } - case PerformErrorNoOperation: - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(p.Msg), - } - case PerformErrRemote: - // if the code is 0 then something bad happened and it isn't - // a remote HTTP error being encapsulated, e.g network error to remote. - if p.RemoteCode == 0 { - return util.ErrorResponse(fmt.Errorf("%s", p.Msg)) - } - return util.JSONResponse{ - Code: p.RemoteCode, - // TODO: Should we assert this is in fact JSON? E.g gjson parse? - JSON: json.RawMessage(p.Msg), - } - default: - return util.ErrorResponse(p) - } +type PerformCreateRoomRequest struct { + InvitedUsers []string + RoomName string + Visibility string + Topic string + StatePreset string + CreationContent json.RawMessage + InitialState []gomatrixserverlib.FledglingEvent + RoomAliasName string + RoomVersion gomatrixserverlib.RoomVersion + PowerLevelContentOverride json.RawMessage + IsDirect bool + + UserDisplayName string + UserAvatarURL string + KeyID gomatrixserverlib.KeyID + PrivateKey ed25519.PrivateKey + EventTime time.Time } -const ( - // PerformErrorNotAllowed means the user is not allowed to invite/join/etc this room (e.g join_rule:invite or banned) - PerformErrorNotAllowed PerformErrorCode = 1 - // PerformErrorBadRequest means the request was wrong in some way (invalid user ID, wrong server, etc) - PerformErrorBadRequest PerformErrorCode = 2 - // PerformErrorNoRoom means that the room being joined doesn't exist. - PerformErrorNoRoom PerformErrorCode = 3 - // PerformErrorNoOperation means that the request resulted in nothing happening e.g invite->invite or leave->leave. - PerformErrorNoOperation PerformErrorCode = 4 - // PerformErrRemote means that the request failed and the PerformError.Msg is the raw remote JSON error response - PerformErrRemote PerformErrorCode = 5 -) - type PerformJoinRequest struct { - RoomIDOrAlias string `json:"room_id_or_alias"` - UserID string `json:"user_id"` - IsGuest bool `json:"is_guest"` - Content map[string]interface{} `json:"content"` - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` - Unsigned map[string]interface{} `json:"unsigned"` -} - -type PerformJoinResponse struct { - // The room ID, populated on success. - RoomID string `json:"room_id"` - JoinedVia gomatrixserverlib.ServerName - // If non-nil, the join request failed. Contains more information why it failed. - Error *PerformError + RoomIDOrAlias string `json:"room_id_or_alias"` + UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` + Content map[string]interface{} `json:"content"` + ServerNames []spec.ServerName `json:"server_names"` + Unsigned map[string]interface{} `json:"unsigned"` } type PerformLeaveRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` + RoomID string + Leaver spec.UserID } type PerformLeaveResponse struct { @@ -103,40 +51,18 @@ type PerformLeaveResponse struct { } type PerformInviteRequest struct { - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` - SendAsServer string `json:"send_as_server"` - TransactionID *TransactionID `json:"transaction_id"` -} - -type PerformInviteResponse struct { - Error *PerformError + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + Event *types.HeaderedEvent `json:"event"` + InviteRoomState []gomatrixserverlib.InviteStrippedState `json:"invite_room_state"` + SendAsServer string `json:"send_as_server"` + TransactionID *TransactionID `json:"transaction_id"` } type PerformPeekRequest struct { - RoomIDOrAlias string `json:"room_id_or_alias"` - UserID string `json:"user_id"` - DeviceID string `json:"device_id"` - ServerNames []gomatrixserverlib.ServerName `json:"server_names"` -} - -type PerformPeekResponse struct { - // The room ID, populated on success. - RoomID string `json:"room_id"` - // If non-nil, the join request failed. Contains more information why it failed. - Error *PerformError -} - -type PerformUnpeekRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - DeviceID string `json:"device_id"` -} - -type PerformUnpeekResponse struct { - // If non-nil, the join request failed. Contains more information why it failed. - Error *PerformError + RoomIDOrAlias string `json:"room_id_or_alias"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + ServerNames []spec.ServerName `json:"server_names"` } // PerformBackfillRequest is a request to PerformBackfill. @@ -148,9 +74,9 @@ type PerformBackfillRequest struct { // The maximum number of events to retrieve. Limit int `json:"limit"` // The server interested in the events. - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` // Which virtual host are we doing this for? - VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` + VirtualHost spec.ServerName `json:"virtual_host"` } // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. @@ -166,7 +92,7 @@ func (r *PerformBackfillRequest) PrevEventIDs() []string { // PerformBackfillResponse is a response to PerformBackfill. type PerformBackfillResponse struct { // Missing events, arbritrary order. - Events []*gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*types.HeaderedEvent `json:"events"` HistoryVisibility gomatrixserverlib.HistoryVisibility `json:"history_visibility"` } @@ -177,17 +103,12 @@ type PerformPublishRequest struct { NetworkID string } -type PerformPublishResponse struct { - // If non-nil, the publish request failed. Contains more information why it failed. - Error *PerformError -} - type PerformInboundPeekRequest struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - PeekID string `json:"peek_id"` - ServerName gomatrixserverlib.ServerName `json:"server_name"` - RenewalInterval int64 `json:"renewal_interval"` + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + PeekID string `json:"peek_id"` + ServerName spec.ServerName `json:"server_name"` + RenewalInterval int64 `json:"renewal_interval"` } type PerformInboundPeekResponse struct { @@ -198,10 +119,10 @@ type PerformInboundPeekResponse struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` // The current state and auth chain events. // The lists will be in an arbitrary order. - StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` - AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` + StateEvents []*types.HeaderedEvent `json:"state_events"` + AuthChainEvents []*types.HeaderedEvent `json:"auth_chain_events"` // The event at which this state was captured - LatestEvent *gomatrixserverlib.HeaderedEvent `json:"latest_event"` + LatestEvent *types.HeaderedEvent `json:"latest_event"` } // PerformForgetRequest is a request to PerformForget @@ -211,50 +132,3 @@ type PerformForgetRequest struct { } type PerformForgetResponse struct{} - -type PerformRoomUpgradeRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` -} - -type PerformRoomUpgradeResponse struct { - NewRoomID string - Error *PerformError -} - -type PerformAdminEvacuateRoomRequest struct { - RoomID string `json:"room_id"` -} - -type PerformAdminEvacuateRoomResponse struct { - Affected []string `json:"affected"` - Error *PerformError -} - -type PerformAdminEvacuateUserRequest struct { - UserID string `json:"user_id"` -} - -type PerformAdminEvacuateUserResponse struct { - Affected []string `json:"affected"` - Error *PerformError -} - -type PerformAdminPurgeRoomRequest struct { - RoomID string `json:"room_id"` -} - -type PerformAdminPurgeRoomResponse struct { - Error *PerformError `json:"error,omitempty"` -} - -type PerformAdminDownloadStateRequest struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - ServerName gomatrixserverlib.ServerName `json:"server_name"` -} - -type PerformAdminDownloadStateResponse struct { - Error *PerformError `json:"error,omitempty"` -} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 612c331567..b6140afd56 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -17,13 +17,17 @@ package api import ( + "context" "encoding/json" "fmt" "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" ) @@ -47,12 +51,12 @@ type QueryLatestEventsAndStateResponse struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` // The latest events in the room. // These are used to set the prev_events when sending an event. - LatestEvents []gomatrixserverlib.EventReference `json:"latest_events"` + LatestEvents []string `json:"latest_events"` // The state events requested. // This list will be in an arbitrary order. // These are used to set the auth_events when sending an event. // These are used to check whether the event is allowed. - StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` + StateEvents []*types.HeaderedEvent `json:"state_events"` // The depth of the latest events. // This is one greater than the maximum depth of the latest events. // This is used to set the depth when sending an event. @@ -82,7 +86,7 @@ type QueryStateAfterEventsResponse struct { PrevEventsExist bool `json:"prev_events_exist"` // The state events requested. // This list will be in an arbitrary order. - StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` + StateEvents []*types.HeaderedEvent `json:"state_events"` } // QueryEventsByIDRequest is a request to QueryEventsByID @@ -103,15 +107,15 @@ type QueryEventsByIDResponse struct { // fails to read it from the database then it will fail // the entire request. // This list will be in an arbitrary order. - Events []*gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*types.HeaderedEvent `json:"events"` } // QueryMembershipForUserRequest is a request to QueryMembership type QueryMembershipForUserRequest struct { // ID of the room to fetch membership from - RoomID string `json:"room_id"` + RoomID string // ID of the user for whom membership is requested - UserID string `json:"user_id"` + UserID spec.UserID } // QueryMembershipForUserResponse is a response to QueryMembership @@ -141,7 +145,7 @@ type QueryMembershipsForRoomRequest struct { // Optional - ID of the user sending the request, for checking if the // user is allowed to see the memberships. If not specified then all // room memberships will be returned. - Sender string `json:"sender"` + SenderID spec.SenderID `json:"sender"` } // QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom @@ -159,7 +163,7 @@ type QueryMembershipsForRoomResponse struct { type QueryServerJoinedToRoomRequest struct { // Server name of the server to find. If not specified, we will // default to checking if the local server is joined. - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` // ID of the room to see if we are still joined to RoomID string `json:"room_id"` } @@ -170,6 +174,8 @@ type QueryServerJoinedToRoomResponse struct { RoomExists bool `json:"room_exists"` // True if we still believe that the server is participating in the room IsInRoom bool `json:"is_in_room"` + // The roomversion if joined to room + RoomVersion gomatrixserverlib.RoomVersion } // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent @@ -177,7 +183,7 @@ type QueryServerAllowedToSeeEventRequest struct { // The event ID to look up invites in. EventID string `json:"event_id"` // The server interested in the event - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` } // QueryServerAllowedToSeeEventResponse is a response to QueryServerAllowedToSeeEvent @@ -195,13 +201,13 @@ type QueryMissingEventsRequest struct { // Limit the number of events this query returns. Limit int `json:"limit"` // The server interested in the event - ServerName gomatrixserverlib.ServerName `json:"server_name"` + ServerName spec.ServerName `json:"server_name"` } // QueryMissingEventsResponse is a response to QueryMissingEvents type QueryMissingEventsResponse struct { // Missing events, arbritrary order. - Events []*gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*types.HeaderedEvent `json:"events"` } // QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain @@ -235,8 +241,8 @@ type QueryStateAndAuthChainResponse struct { StateKnown bool `json:"state_known"` // The state and auth chain events that were requested. // The lists will be in an arbitrary order. - StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` - AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` + StateEvents []*types.HeaderedEvent `json:"state_events"` + AuthChainEvents []*types.HeaderedEvent `json:"auth_chain_events"` // True if the queried event was rejected earlier. IsRejected bool `json:"is_rejected"` } @@ -268,7 +274,7 @@ type QueryAuthChainRequest struct { } type QueryAuthChainResponse struct { - AuthChain []*gomatrixserverlib.HeaderedEvent + AuthChain []*types.HeaderedEvent } type QuerySharedUsersRequest struct { @@ -326,7 +332,7 @@ type QueryCurrentStateRequest struct { } type QueryCurrentStateResponse struct { - StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent + StateEvents map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent } type QueryKnownUsersRequest struct { @@ -340,34 +346,14 @@ type QueryKnownUsersResponse struct { } type QueryServerBannedFromRoomRequest struct { - ServerName gomatrixserverlib.ServerName `json:"server_name"` - RoomID string `json:"room_id"` + ServerName spec.ServerName `json:"server_name"` + RoomID string `json:"room_id"` } type QueryServerBannedFromRoomResponse struct { Banned bool `json:"banned"` } -type QueryRestrictedJoinAllowedRequest struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` -} - -type QueryRestrictedJoinAllowedResponse struct { - // True if the room membership is restricted by the join rule being set to "restricted" - Restricted bool `json:"restricted"` - // True if our local server is joined to all of the allowed rooms specified in the "allow" - // key of the join rule, false if we are missing from some of them and therefore can't - // reliably decide whether or not we can satisfy the join - Resident bool `json:"resident"` - // True if the restricted join is allowed because we found the membership in one of the - // allowed rooms from the join rule, false if not - Allowed bool `json:"allowed"` - // Contains the user ID of the selected user ID that has power to issue invites, this will - // get populated into the "join_authorised_via_users_server" content in the membership - AuthorisedVia string `json:"authorised_via,omitempty"` -} - // MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { se := make(map[string]string) @@ -403,7 +389,7 @@ func (r *QueryBulkStateContentResponse) UnmarshalJSON(data []byte) error { // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { - se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) + se := make(map[string]*types.HeaderedEvent, len(r.StateEvents)) for k, v := range r.StateEvents { // use 0x1F (unit separator) as the delimiter between type/state key, se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v @@ -412,12 +398,12 @@ func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { } func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { - res := make(map[string]*gomatrixserverlib.HeaderedEvent) + res := make(map[string]*types.HeaderedEvent) err := json.Unmarshal(data, &res) if err != nil { return err } - r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res)) + r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(res)) for k, v := range res { fields := strings.Split(k, "\x1F") r.StateEvents[gomatrixserverlib.StateKeyTuple{ @@ -441,7 +427,7 @@ type QueryMembershipAtEventResponse struct { // Membership is a map from eventID to membership event. Events that // do not have known state will return a nil event, resulting in a "leave" membership // when calculating history visibility. - Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"` + Membership map[string]*types.HeaderedEvent `json:"membership"` } // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a @@ -455,3 +441,65 @@ type QueryLeftUsersRequest struct { type QueryLeftUsersResponse struct { LeftUsers []string `json:"user_ids"` } + +type JoinRoomQuerier struct { + Roomserver RestrictedJoinAPI +} + +func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { + return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) +} + +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + return rq.Roomserver.InvitePending(ctx, roomID, senderID) +} + +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := QueryServerJoinedToRoomResponse{} + if err = rq.Roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := rq.Roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil +} + +type MembershipQuerier struct { + Roomserver FederationRoomserverAPI +} + +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { + res := QueryMembershipForUserResponse{} + err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res) + + membership := "" + if err == nil { + membership = res.Membership + } + return membership, err +} diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 5f74c7854a..2505a993b9 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -17,8 +17,10 @@ package api import ( "context" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -26,9 +28,9 @@ import ( // SendEvents to the roomserver The events are written with KindNew. func SendEvents( ctx context.Context, rsAPI InputRoomEventsAPI, - kind Kind, events []*gomatrixserverlib.HeaderedEvent, - virtualHost, origin gomatrixserverlib.ServerName, - sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, + kind Kind, events []*types.HeaderedEvent, + virtualHost, origin spec.ServerName, + sendAsServer spec.ServerName, txnID *TransactionID, async bool, ) error { ires := make([]InputRoomEvent, len(events)) @@ -49,11 +51,11 @@ func SendEvents( // marked as `true` in haveEventIDs. func SendEventWithState( ctx context.Context, rsAPI InputRoomEventsAPI, - virtualHost gomatrixserverlib.ServerName, kind Kind, - state *fclient.RespState, event *gomatrixserverlib.HeaderedEvent, - origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, + virtualHost spec.ServerName, kind Kind, + state gomatrixserverlib.StateResponse, event *types.HeaderedEvent, + origin spec.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers := state.Events(event.RoomVersion) + outliers := gomatrixserverlib.LineariseStateResponse(event.Version(), state) ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { @@ -61,12 +63,12 @@ func SendEventWithState( } ires = append(ires, InputRoomEvent{ Kind: KindOutlier, - Event: outlier.Headered(event.RoomVersion), + Event: &types.HeaderedEvent{PDU: outlier}, Origin: origin, }) } - stateEvents := state.StateEvents.UntrustedEvents(event.RoomVersion) + stateEvents := state.GetStateEvents().UntrustedEvents(event.Version()) stateEventIDs := make([]string, len(stateEvents)) for i := range stateEvents { stateEventIDs[i] = stateEvents[i].EventID() @@ -93,7 +95,7 @@ func SendEventWithState( // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI InputRoomEventsAPI, - virtualHost gomatrixserverlib.ServerName, + virtualHost spec.ServerName, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ @@ -102,14 +104,12 @@ func SendInputRoomEvents( VirtualHost: virtualHost, } var response InputRoomEventsResponse - if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { - return err - } + rsAPI.InputRoomEvents(ctx, &request, &response) return response.Err() } // GetEvent returns the event or nil, even on errors. -func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent { +func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *types.HeaderedEvent { var res QueryEventsByIDResponse err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ RoomID: roomID, @@ -126,7 +126,7 @@ func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) } // GetStateEvent returns the current state event in the room or nil. -func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { +func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *types.HeaderedEvent { var res QueryCurrentStateResponse err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ RoomID: roomID, @@ -144,7 +144,7 @@ func GetStateEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID string, tup } // IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. -func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { +func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, roomID string, serverName spec.ServerName) bool { req := &QueryServerBannedFromRoomRequest{ ServerName: serverName, RoomID: roomID, @@ -163,11 +163,11 @@ func IsServerBannedFromRoom(ctx context.Context, rsAPI FederationRoomserverAPI, func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkStateContentAPI) ([]fclient.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} - canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCanonicalAlias, StateKey: ""} topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} - visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} - joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""} var stateRes QueryBulkStateContentResponse err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{ @@ -175,7 +175,7 @@ func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkS AllowWildcards: true, StateTuples: []gomatrixserverlib.StateKeyTuple{ nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, - {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + {EventType: spec.MRoomMember, StateKey: "*"}, }, }, &stateRes) if err != nil { @@ -191,7 +191,7 @@ func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkS joinCount := 0 var joinRule, guestAccess string for tuple, contentVal := range data { - if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + if tuple.EventType == spec.MRoomMember && contentVal == "join" { joinCount++ continue } @@ -215,7 +215,7 @@ func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI QueryBulkS guestAccess = contentVal } } - if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + if joinRule == spec.Public && guestAccess == "can_join" { pub.GuestCanJoin = true } pub.JoinedMembersCount = joinCount diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index 31a856e8e2..df95851e33 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -13,7 +13,11 @@ package auth import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // TODO: This logic should live in gomatrixserverlib @@ -21,9 +25,10 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( - serverName gomatrixserverlib.ServerName, + ctx context.Context, querier api.QuerySenderIDAPI, + serverName spec.ServerName, serverCurrentlyInRoom bool, - authEvents []*gomatrixserverlib.Event, + authEvents []gomatrixserverlib.PDU, ) bool { historyVisibility := HistoryVisibilityForRoom(authEvents) @@ -32,7 +37,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, gomatrixserverlib.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -41,7 +46,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, gomatrixserverlib.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -50,12 +55,12 @@ func IsServerAllowed( return false } -func HistoryVisibilityForRoom(authEvents []*gomatrixserverlib.Event) gomatrixserverlib.HistoryVisibility { +func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserverlib.HistoryVisibility { // https://matrix.org/docs/spec/client_server/r0.6.0#id87 // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared. visibility := gomatrixserverlib.HistoryVisibilityShared for _, ev := range authEvents { - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { + if ev.Type() != spec.MRoomHistoryVisibility { continue } if vis, err := ev.HistoryVisibility(); err == nil { @@ -65,9 +70,9 @@ func HistoryVisibilityForRoom(authEvents []*gomatrixserverlib.Event) gomatrixser return visibility } -func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, authEvents []*gomatrixserverlib.Event, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { for _, ev := range authEvents { - if ev.Type() != gomatrixserverlib.MRoomMember { + if ev.Type() != spec.MRoomMember { continue } membership, err := ev.Membership() @@ -80,12 +85,16 @@ func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, au continue } - _, domain, err := gomatrixserverlib.SplitID('@', *stateKey) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + continue + } + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey)) if err != nil { continue } - if domain == serverName { + if userID.Domain() == serverName { return true } } diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go new file mode 100644 index 0000000000..058361e6ed --- /dev/null +++ b/roomserver/auth/auth_test.go @@ -0,0 +1,95 @@ +package auth + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +type FakeQuerier struct { + api.QuerySenderIDAPI +} + +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func TestIsServerAllowed(t *testing.T) { + alice := test.NewUser(t) + + tests := []struct { + name string + want bool + roomFunc func() *test.Room + serverName spec.ServerName + serverCurrentlyInRoom bool + }{ + { + name: "no servername specified", + roomFunc: func() *test.Room { return test.NewRoom(t, alice) }, + }, + { + name: "no authEvents specified", + serverName: "test", + roomFunc: func() *test.Room { return &test.Room{} }, + }, + { + name: "default denied", + serverName: "test2", + roomFunc: func() *test.Room { return test.NewRoom(t, alice) }, + }, + { + name: "world readable room", + serverName: "test", + roomFunc: func() *test.Room { + return test.NewRoom(t, alice, test.RoomHistoryVisibility(gomatrixserverlib.HistoryVisibilityWorldReadable)) + }, + want: true, + }, + { + name: "allowed due to alice being joined", + serverName: "test", + roomFunc: func() *test.Room { return test.NewRoom(t, alice) }, + want: true, + }, + { + name: "allowed due to 'serverCurrentlyInRoom'", + serverName: "test2", + roomFunc: func() *test.Room { return test.NewRoom(t, alice) }, + want: true, + serverCurrentlyInRoom: true, + }, + { + name: "allowed due to pending invite", + serverName: "test2", + roomFunc: func() *test.Room { + bob := test.User{ID: "@bob:test2"} + r := test.NewRoom(t, alice, test.RoomHistoryVisibility(gomatrixserverlib.HistoryVisibilityInvited)) + r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Invite, + }, test.WithStateKey(bob.ID)) + return r + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.roomFunc == nil { + t.Fatalf("missing roomFunc") + } + var authEvents []gomatrixserverlib.PDU + for _, ev := range tt.roomFunc().Events() { + authEvents = append(authEvents, ev.PDU) + } + + if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { + t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 94b8b16cf9..b04a56fe81 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -25,7 +25,9 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -111,17 +113,14 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID( return nil } +// nolint:gocyclo // RemoveRoomAlias implements alias.RoomserverInternalAPI +// nolint: gocyclo func (r *RoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID) - if err != nil { - return err - } - roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -132,17 +131,28 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + + sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID) + if err != nil || sender == nil { + return fmt.Errorf("r.QueryUserIDForSender: %w", err) + } + virtualHost := sender.Domain() + response.Found = true creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) } - if creatorID != request.UserID { - var plEvent *gomatrixserverlib.HeaderedEvent + if spec.SenderID(creatorID) != request.SenderID { + var plEvent *types.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent - plEvent, err = r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomPowerLevels, "") + plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "") if err != nil { return fmt.Errorf("r.DB.GetStateEvent: %w", err) } @@ -152,13 +162,13 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("plEvent.PowerLevels: %w", err) } - if pls.UserLevel(request.UserID) < pls.EventLevel(gomatrixserverlib.MRoomCanonicalAlias, true) { + if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { response.Removed = false return nil } } - ev, err := r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomCanonicalAlias, "") + ev, err := r.DB.GetStateEvent(ctx, roomID, spec.MRoomCanonicalAlias, "") if err != nil && err != sql.ErrNoRows { return err } else if ev != nil { @@ -170,30 +180,33 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - sender := request.UserID - if request.UserID != ev.Sender() { - sender = ev.Sender() + senderID := request.SenderID + if request.SenderID != ev.SenderID() { + senderID = ev.SenderID() + } + sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID) + if err != nil || sender == nil { + return err } - _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) + validRoomID, err := spec.NewRoomID(roomID) if err != nil { return err } - - identity, err := r.Cfg.Global.SigningIdentityFor(senderDomain) + identity, err := r.SigningIdentityFor(ctx, *validRoomID, *sender) if err != nil { return err } - builder := &gomatrixserverlib.EventBuilder{ - Sender: sender, + proto := &gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), RoomID: ev.RoomID(), Type: ev.Type(), StateKey: ev.StateKey(), Content: res, } - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) if err != nil { return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) } @@ -202,16 +215,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } stateRes := &api.QueryLatestEventsAndStateResponse{} - if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { return err } - newEvent, err := eventutil.BuildEvent(ctx, builder, &r.Cfg.Global, identity, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, proto, &identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } - err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) + err = api.SendEvents(ctx, r, api.KindNew, []*types.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 7ca3675da6..984dc7d9b2 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -2,9 +2,13 @@ package internal import ( "context" + "crypto/ed25519" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -18,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -39,11 +44,12 @@ type RoomserverInternalAPI struct { *perform.Forgetter *perform.Upgrader *perform.Admin + *perform.Creator ProcessContext *process.ProcessContext DB storage.Database Cfg *config.Dendrite Cache caching.RoomServerCaches - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName KeyRing gomatrixserverlib.JSONVerifier ServerACLs *acls.ServerACLs fsAPI fsAPI.RoomserverFederationAPI @@ -53,7 +59,7 @@ type RoomserverInternalAPI struct { Durable string InputRoomEventTopic string // JetStream topic for new input room events OutputProducer *producers.RoomEventProducer - PerspectiveServerNames []gomatrixserverlib.ServerName + PerspectiveServerNames []spec.ServerName enableMetrics bool } @@ -61,7 +67,7 @@ func NewRoomserverAPI( processContext *process.ProcessContext, dendriteCfg *config.Dendrite, roomserverDB storage.Database, js nats.JetStreamContext, nc *nats.Conn, caches caching.RoomServerCaches, enableMetrics bool, ) *RoomserverInternalAPI { - var perspectiveServerNames []gomatrixserverlib.ServerName + var perspectiveServerNames []spec.ServerName for _, kp := range dendriteCfg.FederationAPI.KeyPerspectives { perspectiveServerNames = append(perspectiveServerNames, kp.ServerName) } @@ -90,6 +96,7 @@ func NewRoomserverAPI( Cache: caches, IsLocalServerName: dendriteCfg.Global.IsLocalServerName, ServerACLs: serverACLs, + Cfg: dendriteCfg, }, enableMetrics: enableMetrics, // perform-er structs get initialised when we have a federation sender to use @@ -104,11 +111,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing - identity, err := r.Cfg.Global.SigningIdentityFor(r.ServerName) - if err != nil { - logrus.Panic(err) - } - r.Inputer = &input.Inputer{ Cfg: &r.Cfg.RoomServer, ProcessContext: r.ProcessContext, @@ -119,16 +121,18 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), ServerName: r.ServerName, - SigningIdentity: identity, + SigningIdentity: r.SigningIdentityFor, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, Queryer: r.Queryer, + EnableMetrics: r.enableMetrics, } r.Inviter = &perform.Inviter{ DB: r.DB, Cfg: &r.Cfg.RoomServer, FSAPI: r.fsAPI, + RSAPI: r, Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ @@ -160,6 +164,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Cfg: &r.Cfg.RoomServer, DB: r.DB, FSAPI: r.fsAPI, + RSAPI: r, Inputer: r.Inputer, } r.Publisher = &perform.Publisher{ @@ -169,6 +174,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio IsLocalServerName: r.Cfg.Global.IsLocalServerName, DB: r.DB, FSAPI: r.fsAPI, + Querier: r.Queryer, KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather @@ -189,6 +195,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, Leaver: r.Leaver, } + r.Creator = &perform.Creator{ + DB: r.DB, + Cfg: &r.Cfg.RoomServer, + RSAPI: r, + } if err := r.Inputer.Start(); err != nil { logrus.WithError(err).Panic("failed to start roomserver input API") @@ -204,20 +215,35 @@ func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalA r.asAPI = asAPI } -func (r *RoomserverInternalAPI) PerformInvite( - ctx context.Context, - req *api.PerformInviteRequest, - res *api.PerformInviteResponse, +func (r *RoomserverInternalAPI) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) { + return r.Inviter.IsKnownRoom(ctx, roomID) +} + +func (r *RoomserverInternalAPI) StateQuerier() gomatrixserverlib.StateQuerier { + return r.Inviter.StateQuerier() +} + +func (r *RoomserverInternalAPI) HandleInvite( + ctx context.Context, inviteEvent *types.HeaderedEvent, ) error { - outputEvents, err := r.Inviter.PerformInvite(ctx, req, res) + outputEvents, err := r.Inviter.ProcessInviteMembership(ctx, inviteEvent) if err != nil { - sentry.CaptureException(err) return err } - if len(outputEvents) == 0 { - return nil - } - return r.OutputProducer.ProduceRoomEvents(req.Event.RoomID(), outputEvents) + return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID(), outputEvents) +} + +func (r *RoomserverInternalAPI) PerformCreateRoom( + ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest, +) (string, *util.JSONResponse) { + return r.Creator.PerformCreateRoom(ctx, userID, roomID, createRequest) +} + +func (r *RoomserverInternalAPI) PerformInvite( + ctx context.Context, + req *api.PerformInviteRequest, +) error { + return r.Inviter.PerformInvite(ctx, req) } func (r *RoomserverInternalAPI) PerformLeave( @@ -243,3 +269,65 @@ func (r *RoomserverInternalAPI) PerformForget( ) error { return r.Forgetter.PerformForget(ctx, req, resp) } + +// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. +func (r *RoomserverInternalAPI) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) { + key, err := r.DB.SelectUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + return nil, err + } + // no key found, create one + if len(key) == 0 { + _, key, err = ed25519.GenerateKey(nil) + if err != nil { + return nil, err + } + key, err = r.DB.InsertUserRoomPrivatePublicKey(ctx, userID, roomID, key) + if err != nil { + return nil, err + } + } + return key, nil +} + +func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error { + pubKeyBytes, err := senderID.RawBytes() + if err != nil { + return err + } + _, err = r.DB.InsertUserRoomPublicKey(ctx, userID, roomID, ed25519.PublicKey(pubKeyBytes)) + return err +} + +func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { + roomVersion, ok := r.Cache.GetRoomVersion(roomID.String()) + if !ok { + roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) + if err != nil { + return fclient.SigningIdentity{}, err + } + if roomInfo != nil { + roomVersion = roomInfo.RoomVersion + } + } + if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID) + if err != nil { + return fclient.SigningIdentity{}, err + } + return fclient.SigningIdentity{ + PrivateKey: privKey, + KeyID: "ed25519:1", + ServerName: "self", + }, nil + } + identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain()) + if err != nil { + return fclient.SigningIdentity{}, err + } + return *identity, err +} + +func (r *RoomserverInternalAPI) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) { + return r.DB.AssignRoomNID(ctx, roomID, roomVersion) +} diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 9defe79451..89fae244fa 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,7 +20,9 @@ import ( "sort" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" @@ -33,8 +35,9 @@ func CheckForSoftFail( ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, - event *gomatrixserverlib.HeaderedEvent, + event *types.HeaderedEvent, stateEventIDs []string, + querier api.QuerySenderIDAPI, ) (bool, error) { rewritesState := len(stateEventIDs) > 1 @@ -48,7 +51,7 @@ func CheckForSoftFail( } else { // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID()) if err != nil { return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) @@ -59,36 +62,39 @@ func CheckForSoftFail( // state because we haven't received a m.room.create event yet. // If we're now processing the first create event then never // soft-fail it. - if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate { + if len(authStateEntries) == 0 && event.Type() == spec.MRoomCreate { return false, nil } // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) + stateNeeded := gomatrixserverlib.StateNeededForAuth( + []gomatrixserverlib.PDU{event.PDU}, + ) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomVersion, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return querier.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { // return true, nil return true, err } return false, nil } -// CheckAuthEvents checks that the event passes authentication checks -// Returns the numeric IDs for the auth events. -func CheckAuthEvents( +// GetAuthEvents returns the numeric IDs for the auth events. +func GetAuthEvents( ctx context.Context, db storage.RoomDatabase, - roomInfo *types.RoomInfo, - event *gomatrixserverlib.HeaderedEvent, + roomVersion gomatrixserverlib.RoomVersion, + event gomatrixserverlib.PDU, authEventIDs []string, -) ([]types.EventNID, error) { +) (gomatrixserverlib.AuthEventProvider, error) { // Grab the numeric IDs for the supplied auth state events from the database. authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs, true) if err != nil { @@ -97,25 +103,14 @@ func CheckAuthEvents( authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomVersion, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } - - // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { - return nil, err - } - - // Return the numeric IDs for the auth events. - result := make([]types.EventNID, len(authStateEntries)) - for i := range authStateEntries { - result[i] = authStateEntries[i].EventNID - } - return result, nil + return &authEvents, nil } type authEvents struct { @@ -131,31 +126,31 @@ func (ae *authEvents) Valid() bool { } // Create implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) { +func (ae *authEvents) Create() (gomatrixserverlib.PDU, error) { return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil } // PowerLevels implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) { +func (ae *authEvents) PowerLevels() (gomatrixserverlib.PDU, error) { return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil } // JoinRules implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) { +func (ae *authEvents) JoinRules() (gomatrixserverlib.PDU, error) { return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil } // Memmber implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) { - return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil +func (ae *authEvents) Member(stateKey spec.SenderID) (gomatrixserverlib.PDU, error) { + return ae.lookupEvent(types.MRoomMemberNID, string(stateKey)), nil } // ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) { +func (ae *authEvents) ThirdPartyInvite(stateKey string) (gomatrixserverlib.PDU, error) { return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil } -func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) gomatrixserverlib.PDU { eventNID, ok := ae.state.lookup(types.StateKeyTuple{ EventTypeNID: typeNID, EventStateKeyNID: types.EmptyStateKeyNID, @@ -167,10 +162,10 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) * if !ok { return nil } - return event.Event + return event.PDU } -func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) gomatrixserverlib.PDU { stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] if !ok { return nil @@ -186,14 +181,14 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * if !ok { return nil } - return event.Event + return event.PDU } // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, - roomInfo *types.RoomInfo, + roomVersion gomatrixserverlib.RoomVersion, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -216,7 +211,8 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil { + + if result.events, err = db.Events(ctx, roomVersion, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 9a70bcc9c3..febabf4114 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -6,9 +6,9 @@ import ( "errors" "fmt" "sort" - "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/roomserver/api" @@ -44,7 +44,7 @@ func UpdateToInviteMembership( updates = append(updates, api.OutputEvent{ Type: api.OutputTypeNewInviteEvent, NewInviteEvent: &api.OutputNewInviteEvent{ - Event: add.Headered(roomVersion), + Event: &types.HeaderedEvent{PDU: add.PDU}, RoomVersion: roomVersion, }, }) @@ -54,9 +54,10 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - Membership: gomatrixserverlib.Join, + RoomID: add.RoomID(), + Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -67,7 +68,7 @@ func UpdateToInviteMembership( // memberships. If the servername is not supplied then the local server will be // checked instead using a faster code path. // TODO: This should probably be replaced by an API call. -func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, err @@ -85,21 +86,21 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, info, eventNIDs) + events, err := db.Events(ctx, info.RoomVersion, eventNIDs) if err != nil { return false, err } - gmslEvents := make([]*gomatrixserverlib.Event, len(events)) + gmslEvents := make([]gomatrixserverlib.PDU, len(events)) for i := range events { - gmslEvents[i] = events[i].Event + gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( ctx context.Context, db storage.Database, - roomID, userID string, -) (bool, string, string, *gomatrixserverlib.Event, error) { + roomID string, senderID spec.SenderID, +) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) { // Look up the room NID for the supplied room ID. info, err := db.RoomInfo(ctx, roomID) if err != nil { @@ -110,13 +111,13 @@ func IsInvitePending( } // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) + targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)}) if err != nil { return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) } - targetUserNID, targetUserFound := targetUserNIDs[userID] + targetUserNID, targetUserFound := targetUserNIDs[string(senderID)] if !targetUserFound { - return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs) } // Let's see if we have an event active for the user in the room. If @@ -148,9 +149,14 @@ func IsInvitePending( return false, "", "", nil, fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) } - event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false, info.RoomVersion) + verImpl, err := gomatrixserverlib.GetRoomVersion(info.RoomVersion) + if err != nil { + return false, "", "", nil, err + } + + event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false) - return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err + return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err } // GetMembershipsAtState filters the state events to @@ -177,7 +183,10 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) + if roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { return nil, err } @@ -194,7 +203,7 @@ func GetMembershipsAtState( return nil, err } - if membership == gomatrixserverlib.Join { + if membership == spec.Join { events = append(events, event) } } @@ -202,8 +211,8 @@ func GetMembershipsAtState( return events, nil } -func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -220,30 +229,33 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } func LoadEvents( ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID, -) ([]*gomatrixserverlib.Event, error) { - stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) +) ([]gomatrixserverlib.PDU, error) { + if roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { return nil, err } - result := make([]*gomatrixserverlib.Event, len(stateEvents)) + result := make([]gomatrixserverlib.PDU, len(stateEvents)) for i := range stateEvents { - result[i] = stateEvents[i].Event + result[i] = stateEvents[i].PDU } return result, nil } func LoadStateEvents( ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, -) ([]*gomatrixserverlib.Event, error) { +) ([]gomatrixserverlib.PDU, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID @@ -252,7 +264,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -261,7 +273,7 @@ func CheckServerAllowedToSeeEvent( case tables.OptimisationNotSupportedError: // The database engine didn't support this optimisation, so fall back to using // the old and slow method - stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier) if err != nil { return false, err } @@ -276,13 +288,13 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, -) ([]*gomatrixserverlib.Event, error) { - roomState := state.NewStateResolution(db, info) + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI, +) ([]gomatrixserverlib.PDU, error) { + roomState := state.NewStateResolution(db, info, querier) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -306,9 +318,18 @@ func slowGetHistoryVisibilityState( // If the event state key doesn't match the given servername // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } for nid, key := range stateKeys { - if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { - delete(stateKeys, nid) + if key != "" { + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key)) + if err == nil && userID != nil { + if userID.Domain() != serverName { + delete(stateKeys, nid) + } + } } } @@ -332,7 +353,7 @@ func slowGetHistoryVisibilityState( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]types.EventNID, map[string]struct{}, error) { var resultNIDs []types.EventNID var err error @@ -375,7 +396,7 @@ BFSLoop: // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] - isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } @@ -398,7 +419,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -427,7 +448,7 @@ BFSLoop: } func QueryLatestEventsAndState( - ctx context.Context, db storage.Database, + ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { @@ -440,7 +461,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion @@ -473,7 +494,7 @@ func QueryLatestEventsAndState( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) + response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event}) } return nil diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index dd74b844a0..1cef83df75 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,7 +42,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { var authNIDs []types.EventNID for _, x := range room.Events() { - roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap()) + roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.PDU) assert.NoError(t, err) assert.NotNil(t, roomInfo) @@ -52,18 +53,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) assert.NoError(t, err) - evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false) + evNID, _, err := db.StoreEvent(context.Background(), x.PDU, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false) assert.NoError(t, err) authNIDs = append(authNIDs, evNID) } // Alice should have no pending invites and should have a NID - pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") // Bob should have no pending invites and receive a new NID - pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID)) assert.NoError(t, err, "failed to get pending invites") assert.False(t, pendingInvite, "unexpected pending invite") }) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 83aa9e90fe..a8afbc3135 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -25,6 +25,7 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" @@ -79,8 +80,8 @@ type Inputer struct { NATSClient *nats.Conn JetStream nats.JetStreamContext Durable nats.SubOpt - ServerName gomatrixserverlib.ServerName - SigningIdentity *fclient.SigningIdentity + ServerName spec.ServerName + SigningIdentity func(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -90,7 +91,7 @@ type Inputer struct { Queryer *query.Queryer UserAPI userapi.RoomserverUserAPI - enableMetrics bool + EnableMetrics bool } // If a room consumer is inactive for a while then we will allow NATS @@ -177,7 +178,7 @@ func (r *Inputer) startWorkerForRoom(roomID string) { // will look to see if we have a worker for that room which has its // own consumer. If we don't, we'll start one. func (r *Inputer) Start() error { - if r.enableMetrics { + if r.EnableMetrics { prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) } _, err := r.JetStream.Subscribe( @@ -284,7 +285,7 @@ func (w *worker) _next() { var errString string if err = w.r.processRoomEvent( w.r.ProcessContext.Context(), - gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")), + spec.ServerName(msg.Header.Get("virtual_host")), &inputRoomEvent, ); err != nil { switch err.(type) { @@ -388,18 +389,18 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { // Queue up the event into the roomserver. replySub, err := r.queueInputRoomEvents(ctx, request) if err != nil { response.ErrMsg = err.Error() - return nil + return } // If we aren't waiting for synchronous responses then we can // give up here, there is nothing further to do. if replySub == nil { - return nil + return } // Otherwise, we'll want to sit and wait for the responses @@ -411,14 +412,12 @@ func (r *Inputer) InputRoomEvents( msg, err := replySub.NextMsgWithContext(ctx) if err != nil { response.ErrMsg = err.Error() - return nil + return } if len(msg.Data) > 0 { response.ErrMsg = string(msg.Data) } } - - return nil } var roomserverInputBackpressure = prometheus.NewGaugeVec( diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 971befa076..db3c955025 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" @@ -73,7 +74,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - virtualHost gomatrixserverlib.ServerName, + virtualHost spec.ServerName, input *api.InputRoomEvent, ) error { select { @@ -101,7 +102,7 @@ func (r *Inputer) processRoomEvent( // Parse and validate the event JSON headered := input.Event - event := headered.Unwrap() + event := headered.PDU logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": event.EventID(), "room_id": event.RoomID(), @@ -123,13 +124,21 @@ func (r *Inputer) processRoomEvent( if rerr != nil { return fmt.Errorf("r.DB.RoomInfo: %w", rerr) } - isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") + isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + validRoomID, err := spec.NewRoomID(event.RoomID()) if err != nil { - return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) + return err + } + sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err != nil { + return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) + } + senderDomain := spec.ServerName("") + if sender != nil { + senderDomain = sender.Domain() } // If we already know about this outlier and it hasn't been rejected @@ -180,7 +189,7 @@ func (r *Inputer) processRoomEvent( // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the // event origin are first on the list. - servers := map[gomatrixserverlib.ServerName]struct{}{} + servers := map[spec.ServerName]struct{}{} for _, server := range serverRes.ServerNames { servers[server] = struct{}{} } @@ -192,7 +201,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + // Only perform this check if the sender mxid_mapping can be resolved. + // Don't fail processing the event if we have no mxid_maping. + if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) delete(servers, senderDomain) } @@ -231,10 +242,10 @@ func (r *Inputer) processRoomEvent( roomsMu: internal.NewMutexByRoom(), servers: serverRes.ServerNames, hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.Event{}, + haveEvents: map[string]gomatrixserverlib.PDU{}, } var stateSnapshot *parsedRespState - if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.Version()); err != nil { // Something went wrong with retrieving the missing state, so we can't // really do anything with the event other than reject it at this point. isRejected = true @@ -275,7 +286,9 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { isRejected = true rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) @@ -312,7 +325,7 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew && !isCreateEvent { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -388,12 +401,12 @@ func (r *Inputer) processRoomEvent( // we do this after calculating state for this event as we may need to get power levels var ( redactedEventID string - redactionEvent *gomatrixserverlib.Event - redactedEvent *gomatrixserverlib.Event + redactionEvent gomatrixserverlib.PDU + redactedEvent gomatrixserverlib.PDU ) if !isRejected && !isCreateEvent { - resolver := state.NewStateResolution(r.DB, roomInfo) - redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) + resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer) + redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer) if err != nil { return err } @@ -466,7 +479,7 @@ func (r *Inputer) processRoomEvent( Type: api.OutputTypeRedactedEvent, RedactedEvent: &api.OutputRedactedEvent{ RedactedEventID: redactedEventID, - RedactedBecause: redactionEvent.Headered(headered.RoomVersion), + RedactedBecause: &types.HeaderedEvent{PDU: redactionEvent}, }, }, }) @@ -476,8 +489,8 @@ func (r *Inputer) processRoomEvent( } // If guest_access changed and is not can_join, kick all guest users. - if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { - if err = r.kickGuests(ctx, event, roomInfo); err != nil { + if event.Type() == spec.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") } } @@ -489,10 +502,10 @@ func (r *Inputer) processRoomEvent( } // handleRemoteRoomUpgrade updates published rooms and room aliases -func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error { +func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID())) } // processStateBefore works out what the state is before the event and @@ -508,9 +521,9 @@ func (r *Inputer) processStateBefore( missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared. - event := input.Event.Unwrap() - isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") - var stateBeforeEvent []*gomatrixserverlib.Event + event := input.Event.PDU + isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") + var stateBeforeEvent []gomatrixserverlib.PDU switch { case isCreateEvent: // There's no state before a create event so there is nothing @@ -523,9 +536,9 @@ func (r *Inputer) processStateBefore( if err != nil { return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) } - stateBeforeEvent = make([]*gomatrixserverlib.Event, 0, len(stateEvents)) + stateBeforeEvent = make([]gomatrixserverlib.PDU, 0, len(stateEvents)) for _, entry := range stateEvents { - stateBeforeEvent = append(stateBeforeEvent, entry.Event) + stateBeforeEvent = append(stateBeforeEvent, entry.PDU) } case missingPrev: // We don't know all of the prev events, so we can't work out @@ -544,9 +557,9 @@ func (r *Inputer) processStateBefore( // will include the history visibility here even though we don't // actually need it for auth, because we want to send it in the // output events. - tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event}).Tuples() + tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event}).Tuples() tuplesNeeded = append(tuplesNeeded, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", }) stateBeforeReq := &api.QueryStateAfterEventsRequest{ @@ -566,20 +579,28 @@ func (r *Inputer) processStateBefore( rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID()) return default: - stateBeforeEvent = gomatrixserverlib.UnwrapEventHeaders(stateBeforeRes.StateEvents) + stateBeforeEvent = make([]gomatrixserverlib.PDU, len(stateBeforeRes.StateEvents)) + for i := range stateBeforeRes.StateEvents { + stateBeforeEvent[i] = stateBeforeRes.StateEvents[i].PDU + } } } // At this point, stateBeforeEvent should be populated either by // the supplied state in the input request, or from the prev events. // Check whether the event is allowed or not. - stateBeforeAuth := gomatrixserverlib.NewAuthEvents(stateBeforeEvent) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { + stateBeforeAuth := gomatrixserverlib.NewAuthEvents( + gomatrixserverlib.ToPDUs(stateBeforeEvent), + ) + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); rejectionErr != nil { + rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return } // Work out what the history visibility was at the time of the // event. for _, event := range stateBeforeEvent { - if event.Type() != gomatrixserverlib.MRoomHistoryVisibility || !event.StateKeyEquals("") { + if event.Type() != spec.MRoomHistoryVisibility || !event.StateKeyEquals("") { continue } if hisVis, err := event.HistoryVisibility(); err == nil { @@ -602,11 +623,11 @@ func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, roomInfo *types.RoomInfo, - virtualHost gomatrixserverlib.ServerName, - event *gomatrixserverlib.HeaderedEvent, + virtualHost spec.ServerName, + event *types.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, - servers []gomatrixserverlib.ServerName, + servers []spec.ServerName, ) error { trace, ctx := internal.StartRegion(ctx, "fetchAuthEvents") defer trace.EndRegion() @@ -619,7 +640,7 @@ func (r *Inputer) fetchAuthEvents( for _, authEventID := range authEventIDs { authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID}) - if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { + if err != nil || len(authEvents) == 0 || authEvents[0].PDU == nil { unknown[authEventID] = struct{}{} continue } @@ -634,7 +655,7 @@ func (r *Inputer) fetchAuthEvents( } known[authEventID] = &ev // don't take the pointer of the iterated event if !isRejected { - if err = auth.AddEvent(ev.Event); err != nil { + if err = auth.AddEvent(ev.PDU); err != nil { return fmt.Errorf("auth.AddEvent: %w", err) } } @@ -653,7 +674,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue @@ -670,7 +691,7 @@ func (r *Inputer) fetchAuthEvents( isRejected := false nextAuthEvent: for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( - res.AuthEvents.UntrustedEvents(event.RoomVersion), + gomatrixserverlib.ToPDUs(res.AuthEvents.UntrustedEvents(event.Version())), gomatrixserverlib.TopologicalOrderByAuthEvents, ) { // If we already know about this event from the database then we don't @@ -683,7 +704,9 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue nextAuthEvent } @@ -699,7 +722,9 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } @@ -738,7 +763,7 @@ nextAuthEvent: // Now we know about this event, it was stored and the signatures were OK. known[authEvent.EventID()] = &types.Event{ EventNID: eventNID, - Event: authEvent, + PDU: authEvent, } } @@ -750,7 +775,7 @@ func (r *Inputer) calculateAndSetState( input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, - event *gomatrixserverlib.Event, + event gomatrixserverlib.PDU, isRejected bool, ) error { trace, ctx := internal.StartRegion(ctx, "calculateAndSetState") @@ -762,7 +787,7 @@ func (r *Inputer) calculateAndSetState( return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) } defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - roomState := state.NewStateResolution(updater, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo, r.Queryer) if input.HasState { // We've been told what the state at the event is so we don't need to calculate it. @@ -792,13 +817,16 @@ func (r *Inputer) calculateAndSetState( } // kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. -func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { +func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo) error { membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) if err != nil { return err } - memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs) + if roomInfo == nil { + return types.ErrorInvalidRoomInfo + } + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, membershipNIDs) if err != nil { return err } @@ -812,21 +840,26 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event return err } + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + prevEvents := latestRes.LatestEvents for _, memberEvent := range memberEvents { if memberEvent.StateKey() == nil { continue } - localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ - Localpart: localpart, - ServerName: senderDomain, + Localpart: memberUserID.Local(), + ServerName: memberUserID.Domain(), }, accountRes); err != nil { return err } @@ -842,14 +875,14 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { return err } - memberContent.Membership = gomatrixserverlib.Leave + memberContent.Membership = spec.Leave stateKey := *memberEvent.StateKey() - fledglingEvent := &gomatrixserverlib.EventBuilder{ + fledglingEvent := &gomatrixserverlib.ProtoEvent{ RoomID: event.RoomID(), - Type: gomatrixserverlib.MRoomMember, + Type: spec.MRoomMember, StateKey: &stateKey, - Sender: stateKey, + SenderID: stateKey, PrevEvents: prevEvents, } @@ -857,12 +890,27 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event return err } - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(fledglingEvent) if err != nil { return err } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + + userID, err := spec.NewUserID(stateKey, true) + if err != nil { + return err + } + + signingIdentity, err := r.SigningIdentity(ctx, *validRoomID, *userID) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, &signingIdentity, time.Now(), &eventsNeeded, latestRes) if err != nil { return err } @@ -870,12 +918,10 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event inputEvents = append(inputEvents, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: senderDomain, - SendAsServer: string(senderDomain), + Origin: memberUserID.Domain(), + SendAsServer: string(memberUserID.Domain()), }) - prevEvents = []gomatrixserverlib.EventReference{ - event.EventReference(), - } + prevEvents = []string{event.EventID()} } inputReq := &api.InputRoomEventsRequest{ @@ -883,5 +929,6 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event Asynchronous: true, // Needs to be async, as we otherwise create a deadlock } inputRes := &api.InputRoomEventsResponse{} - return r.InputRoomEvents(ctx, inputReq, inputRes) + r.InputRoomEvents(ctx, inputReq, inputRes) + return nil } diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 818e7715c6..4ee6d21106 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/test" ) @@ -17,29 +18,29 @@ func Test_EventAuth(t *testing.T) { room2 := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) authEventIDs := make([]string, 0, 4) - authEvents := []*gomatrixserverlib.Event{} + authEvents := []gomatrixserverlib.PDU{} // Add the legal auth events from room2 for _, x := range room2.Events() { - if x.Type() == gomatrixserverlib.MRoomCreate { + if x.Type() == spec.MRoomCreate { authEventIDs = append(authEventIDs, x.EventID()) - authEvents = append(authEvents, x.Event) + authEvents = append(authEvents, x.PDU) } - if x.Type() == gomatrixserverlib.MRoomPowerLevels { + if x.Type() == spec.MRoomPowerLevels { authEventIDs = append(authEventIDs, x.EventID()) - authEvents = append(authEvents, x.Event) + authEvents = append(authEvents, x.PDU) } - if x.Type() == gomatrixserverlib.MRoomJoinRules { + if x.Type() == spec.MRoomJoinRules { authEventIDs = append(authEventIDs, x.EventID()) - authEvents = append(authEvents, x.Event) + authEvents = append(authEvents, x.PDU) } } // Add the illegal auth event from room1 (rooms are different) for _, x := range room1.Events() { - if x.Type() == gomatrixserverlib.MRoomMember { + if x.Type() == spec.MRoomMember { authEventIDs = append(authEventIDs, x.EventID()) - authEvents = append(authEvents, x.Event) + authEvents = append(authEvents, x.PDU) } } @@ -57,7 +58,9 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.Event, &allower); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) + }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 09db184314..940783e03d 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -53,7 +53,7 @@ func (r *Inputer) updateLatestEvents( ctx context.Context, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, - event *gomatrixserverlib.Event, + event gomatrixserverlib.PDU, sendAsServer string, transactionID *api.TransactionID, rewritesState bool, @@ -101,7 +101,7 @@ type latestEventsUpdater struct { updater *shared.RoomUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent - event *gomatrixserverlib.Event + event gomatrixserverlib.PDU transactionID *api.TransactionID rewritesState bool // Which server to send this event as. @@ -154,8 +154,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { extremitiesChanged, err := u.calculateLatest( u.oldLatest, u.event, types.StateAtEventAndReference{ - EventReference: u.event.EventReference(), - StateAtEvent: u.stateAtEvent, + EventID: u.event.EventID(), + StateAtEvent: u.stateAtEvent, }, ) if err != nil { @@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error { defer trace.EndRegion() var err error - roomState := state.NewStateResolution(u.updater, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the @@ -326,7 +326,7 @@ func (u *latestEventsUpdater) latestState() error { // true if the new event is included in those extremites, false otherwise. func (u *latestEventsUpdater) calculateLatest( oldLatest []types.StateAtEventAndReference, - newEvent *gomatrixserverlib.Event, + newEvent gomatrixserverlib.PDU, newStateAndRef types.StateAtEventAndReference, ) (bool, error) { trace, _ := internal.StartRegion(u.ctx, "calculateLatest") @@ -349,7 +349,7 @@ func (u *latestEventsUpdater) calculateLatest( // If the "new" event is already referenced by an existing event // then do nothing - it's not a candidate to be a new extremity if // it has been referenced. - if referenced, err := u.updater.IsReferenced(newEvent.EventReference()); err != nil { + if referenced, err := u.updater.IsReferenced(newEvent.EventID()); err != nil { return false, fmt.Errorf("u.updater.IsReferenced(new): %w", err) } else if referenced { u.latest = oldLatest @@ -360,7 +360,7 @@ func (u *latestEventsUpdater) calculateLatest( // have entries in the previous events table. If they do then we // will no longer include them as forward extremities. for k, l := range existingRefs { - referenced, err := u.updater.IsReferenced(l.EventReference) + referenced, err := u.updater.IsReferenced(l.EventID) if err != nil { return false, fmt.Errorf("u.updater.IsReferenced: %w", err) } else if referenced { @@ -393,7 +393,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) } ore := api.OutputNewRoomEvent{ - Event: u.event.Headered(u.roomInfo.RoomVersion), + Event: &types.HeaderedEvent{PDU: u.event}, RewritesState: u.rewritesState, LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 4028f0b5ea..c46f8dba13 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -18,7 +18,7 @@ import ( "context" "fmt" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -54,7 +54,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, nil, eventNIDs) + events, err := updater.Events(ctx, "", eventNIDs) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func (r *Inputer) updateMemberships( if change.addedEventNID != 0 { ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID) } - if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil { return nil, err } } @@ -79,6 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( + ctx context.Context, updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *types.Event, @@ -86,7 +87,7 @@ func (r *Inputer) updateMembership( ) ([]api.OutputEvent, error) { var err error // Default the membership to Leave if no event was added or removed. - newMembership := gomatrixserverlib.Leave + newMembership := spec.Leave if add != nil { newMembership, err = add.Membership() if err != nil { @@ -96,7 +97,7 @@ func (r *Inputer) updateMembership( var targetLocal bool if add != nil { - targetLocal = r.isLocalTarget(add) + targetLocal = r.isLocalTarget(ctx, add) } mu, err := updater.MembershipUpdater(targetUserNID, targetLocal) @@ -120,13 +121,13 @@ func (r *Inputer) updateMembership( } switch newMembership { - case gomatrixserverlib.Invite: + case spec.Invite: return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion()) - case gomatrixserverlib.Join: + case spec.Join: return updateToJoinMembership(mu, add, updates) - case gomatrixserverlib.Leave, gomatrixserverlib.Ban: + case spec.Leave, spec.Ban: return updateToLeaveMembership(mu, add, newMembership, updates) - case gomatrixserverlib.Knock: + case spec.Knock: return updateToKnockMembership(mu, add, updates) default: panic(fmt.Errorf( @@ -135,11 +136,18 @@ func (r *Inputer) updateMembership( } } -func (r *Inputer) isLocalTarget(event *types.Event) bool { +func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) - isTargetLocalUser = domain == r.ServerName + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return isTargetLocalUser + } + userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey)) + if err != nil || userID == nil { + return isTargetLocalUser + } + isTargetLocalUser = userID.Domain() == r.ServerName } return isTargetLocalUser } @@ -160,9 +168,10 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - Membership: gomatrixserverlib.Join, + RoomID: add.RoomID(), + Membership: spec.Join, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } @@ -186,9 +195,10 @@ func updateToLeaveMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, + RoomID: add.RoomID(), Membership: newMembership, RetiredByEventID: add.EventID(), - TargetUserID: *add.StateKey(), + TargetSenderID: spec.SenderID(*add.StateKey()), }, }) } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 74b138741f..7ee84e4c06 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -9,6 +9,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -21,39 +22,40 @@ import ( ) type parsedRespState struct { - AuthEvents []*gomatrixserverlib.Event - StateEvents []*gomatrixserverlib.Event + AuthEvents []gomatrixserverlib.PDU + StateEvents []gomatrixserverlib.PDU } -func (p *parsedRespState) Events() []*gomatrixserverlib.Event { - eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) +func (p *parsedRespState) Events() []gomatrixserverlib.PDU { + eventsByID := make(map[string]gomatrixserverlib.PDU, len(p.AuthEvents)+len(p.StateEvents)) for i, event := range p.AuthEvents { eventsByID[event.EventID()] = p.AuthEvents[i] } for i, event := range p.StateEvents { eventsByID[event.EventID()] = p.StateEvents[i] } - allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID)) + allEvents := make([]gomatrixserverlib.PDU, 0, len(eventsByID)) for _, event := range eventsByID { allEvents = append(allEvents, event) } - return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) + return gomatrixserverlib.ReverseTopologicalOrdering( + gomatrixserverlib.ToPDUs(allEvents), gomatrixserverlib.TopologicalOrderByAuthEvents) } type missingStateReq struct { log *logrus.Entry - virtualHost gomatrixserverlib.ServerName - origin gomatrixserverlib.ServerName + virtualHost spec.ServerName + origin spec.ServerName db storage.RoomDatabase roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier federation fedapi.RoomserverFederationAPI roomsMu *internal.MutexByRoom - servers []gomatrixserverlib.ServerName + servers []spec.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex - haveEvents map[string]*gomatrixserverlib.Event + haveEvents map[string]gomatrixserverlib.PDU haveEventsMutex sync.Mutex } @@ -61,7 +63,7 @@ type missingStateReq struct { // request, as called from processRoomEvent. // nolint:gocyclo func (t *missingStateReq) processEventWithMissingState( - ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, e gomatrixserverlib.PDU, roomVersion gomatrixserverlib.RoomVersion, ) (*parsedRespState, error) { trace, ctx := internal.StartRegion(ctx, "processEventWithMissingState") defer trace.EndRegion() @@ -105,7 +107,7 @@ func (t *missingStateReq) processEventWithMissingState( for _, newEvent := range newEvents { err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, - Event: newEvent.Headered(roomVersion), + Event: &types.HeaderedEvent{PDU: newEvent}, Origin: t.origin, SendAsServer: api.DoNotSendToOtherServers, }) @@ -154,7 +156,7 @@ func (t *missingStateReq) processEventWithMissingState( } outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ Kind: api.KindOutlier, - Event: outlier.Headered(roomVersion), + Event: &types.HeaderedEvent{PDU: outlier}, Origin: t.origin, }) } @@ -184,7 +186,7 @@ func (t *missingStateReq) processEventWithMissingState( err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, - Event: backwardsExtremity.Headered(roomVersion), + Event: &types.HeaderedEvent{PDU: backwardsExtremity}, Origin: t.origin, HasState: true, StateEventIDs: stateIDs, @@ -203,7 +205,7 @@ func (t *missingStateReq) processEventWithMissingState( for _, newEvent := range newEvents { err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, - Event: newEvent.Headered(roomVersion), + Event: &types.HeaderedEvent{PDU: newEvent}, Origin: t.origin, SendAsServer: api.DoNotSendToOtherServers, }) @@ -241,7 +243,7 @@ func (t *missingStateReq) processEventWithMissingState( return resolvedState, nil } -func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { +func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e gomatrixserverlib.PDU, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { trace, ctx := internal.StartRegion(ctx, "lookupResolvedStateBeforeEvent") defer trace.EndRegion() @@ -279,7 +281,7 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e resolvedState := &parsedRespState{} switch len(states) { case 0: - extremityIsCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + extremityIsCreate := e.Type() == spec.MRoomCreate && e.StateKeyEquals("") if !extremityIsCreate { // There are no previous states and this isn't the beginning of the // room - this is an error condition! @@ -291,7 +293,7 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // use it as-is. There's no point in resolving it again. Only trust a // trustworthy state snapshot if it actually contains some state for all // non-create events, otherwise we need to resolve what came from federation. - isCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + isCreate := e.Type() == spec.MRoomCreate && e.StateKeyEquals("") if states[0].trustworthy && (isCreate || len(states[0].StateEvents) > 0) { resolvedState = states[0].parsedRespState break @@ -366,7 +368,7 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion return respState, false, nil } -func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event { +func (t *missingStateReq) cacheAndReturn(ev gomatrixserverlib.PDU) gomatrixserverlib.PDU { t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() if cached, exists := t.haveEvents[ev.EventID()]; exists { @@ -381,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even defer trace.EndRegion() var res parsedRespState - roomState := state.NewStateResolution(t.db, t.roomInfo) + roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer) stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) if err != nil { t.log.WithError(err).Warnf("failed to get state after %s locally", eventID) @@ -396,16 +398,19 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs) + if t.roomInfo == nil { + return nil + } + stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomVersion, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil } - res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)) + res.StateEvents = make([]gomatrixserverlib.PDU, 0, len(stateEvents)) for _, ev := range stateEvents { // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. - res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event)) + res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.PDU)) t.hadEvent(ev.EventID()) } @@ -413,7 +418,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even stateEvents, stateEventNIDs, stateEntries, stateAtEvents = nil, nil, nil, nil // nolint:ineffassign missingAuthEvents := map[string]bool{} - res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3) + res.AuthEvents = make([]gomatrixserverlib.PDU, 0, len(stateEvents)*3) for _, ev := range stateEvents { t.haveEventsMutex.Lock() for _, ae := range ev.AuthEventIDs() { @@ -438,7 +443,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even return nil } for i, ev := range events { - res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event)) + res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].PDU)) t.hadEvent(ev.EventID()) } } @@ -457,23 +462,29 @@ func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersio return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } -func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { +func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity gomatrixserverlib.PDU) (*parsedRespState, error) { trace, ctx := internal.StartRegion(ctx, "resolveStatesAndCheck") defer trace.EndRegion() - var authEventList []*gomatrixserverlib.Event - var stateEventList []*gomatrixserverlib.Event + var authEventList []gomatrixserverlib.PDU + var stateEventList []gomatrixserverlib.PDU for _, state := range states { authEventList = append(authEventList, state.AuthEvents...) stateEventList = append(stateEventList, state.StateEvents...) } - resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }, + ) if err != nil { return nil, err } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) @@ -503,7 +514,7 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events -func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { +func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.PDU, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.PDU, isGapFilled, prevStateKnown bool, err error) { trace, ctx := internal.StartRegion(ctx, "getMissingEvents") defer trace.EndRegion() @@ -513,9 +524,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) } latestEvents := make([]string, len(latest)) - for i, ev := range latest { - latestEvents[i] = ev.EventID - t.hadEvent(ev.EventID) + for i := range latest { + latestEvents[i] = latest[i] + t.hadEvent(latest[i]) } var missingResp *fclient.RespMissingEvents @@ -556,9 +567,11 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // Make sure events from the missingResp are using the cache - missing events // will be added and duplicates will be removed. - missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) + missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = ev.VerifyEventSignatures(ctx, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } missingEvents = append(missingEvents, t.cacheAndReturn(ev)) @@ -566,7 +579,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve logger.Debugf("get_missing_events returned %d events (%d passed signature checks)", len(missingResp.Events), len(missingEvents)) // topologically sort and sanity check that we are making forward progress - newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingEvents, gomatrixserverlib.TopologicalOrderByPrevEvents) + newEvents = gomatrixserverlib.ReverseTopologicalOrdering( + gomatrixserverlib.ToPDUs(missingEvents), gomatrixserverlib.TopologicalOrderByPrevEvents) shouldHaveSomeEventIDs := e.PrevEventIDs() hasPrevEvent := false Event: @@ -597,7 +611,7 @@ Event: // If we retrieved back to the beginning of the room then there's nothing else // to do - we closed the gap. - if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { + if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == spec.MRoomCreate && earliestNewEvent.StateKeyEquals("") { return newEvents, true, t.isPrevStateKnown(ctx, e), nil } @@ -612,7 +626,7 @@ Event: return newEvents, true, t.isPrevStateKnown(ctx, e), nil } -func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserverlib.Event) bool { +func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e gomatrixserverlib.PDU) bool { expected := len(e.PrevEventIDs()) state, err := t.db.StateAtEventIDs(ctx, e.PrevEventIDs()) if err != nil || len(state) != expected { @@ -641,31 +655,21 @@ func (t *missingStateReq) lookupMissingStateViaState( if err != nil { return nil, err } - s := fclient.RespState{ + + // Check that the returned state is valid. + authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - } - // Check that the returned state is valid. - authEvents, stateEvents, err := s.Check(ctx, roomVersion, t.keys, nil) + }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } - parsedState := &parsedRespState{ + return &parsedRespState{ AuthEvents: authEvents, StateEvents: stateEvents, - } - // Cache the results of this state lookup and deduplicate anything we already - // have in the cache, freeing up memory. - // We load these as trusted as we called state.Check before which loaded them as untrusted. - for i, evJSON := range s.AuthEvents { - ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.AuthEvents[i] = t.cacheAndReturn(ev) - } - for i, evJSON := range s.StateEvents { - ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.StateEvents[i] = t.cacheAndReturn(ev) - } - return parsedState, nil + }, nil } func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( @@ -713,7 +717,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } for i, ev := range events { - events[i].Event = t.cacheAndReturn(events[i].Event) + events[i].PDU = t.cacheAndReturn(events[i].PDU) t.hadEvent(ev.EventID()) evID := events[i].EventID() if missing[evID] { @@ -845,20 +849,25 @@ func (t *missingStateReq) createRespStateFromStateIDs( return &respState, nil } -func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { +func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (gomatrixserverlib.PDU, error) { trace, ctx := internal.StartRegion(ctx, "lookupEvent") defer trace.EndRegion() + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return nil, err + } + if localFirst { // fetch from the roomserver events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID}) if err != nil { t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(events) == 1 { - return events[0].Event, nil + return events[0].PDU, nil } } - var event *gomatrixserverlib.Event + var event gomatrixserverlib.PDU found := false for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) @@ -876,7 +885,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } continue } - event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) + event, err = verImpl.NewEventFromUntrustedJSON(txn.PDUs[0]) if err != nil { t.log.WithError(err).WithField("missing_event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event") continue @@ -888,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event), nil } -func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { +func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) @@ -903,7 +914,7 @@ func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserv return err } } - return gomatrixserverlib.Allowed(e, &authUsingState) + return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender) } func (t *missingStateReq) hadEvent(eventID string) { diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 51c50c37a5..f435181a03 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -35,16 +36,16 @@ func TestSingleTransactionOnInput(t *testing.T) { ctx, cancel := context.WithDeadline(processCtx.Context(), deadline) defer cancel() - event, err := gomatrixserverlib.NewEventFromTrustedJSON( + event, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV6).NewEventFromTrustedJSON( []byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`), - false, gomatrixserverlib.RoomVersionV6, + false, ) if err != nil { t.Fatal(err) } in := api.InputRoomEvent{ Kind: api.KindOutlier, // don't panic if we generate an output event - Event: event.Headered(gomatrixserverlib.RoomVersionV6), + Event: &types.HeaderedEvent{PDU: event}, } inputter := &input.Inputer{ diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index f35e40bc91..12b557f510 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -26,8 +26,11 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -39,61 +42,48 @@ type Admin struct { Leaver *Leaver } -// PerformEvacuateRoom will remove all local users from the given room. +// PerformAdminEvacuateRoom will remove all local users from the given room. func (r *Admin) PerformAdminEvacuateRoom( ctx context.Context, - req *api.PerformAdminEvacuateRoomRequest, - res *api.PerformAdminEvacuateRoomResponse, -) error { - roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + roomID string, +) (affected []string, err error) { + roomInfo, err := r.DB.RoomInfo(ctx, roomID) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), - } - return nil + return nil, err } if roomInfo == nil || roomInfo.IsStub() { - res.Error = &api.PerformError{ - Code: api.PerformErrorNoRoom, - Msg: fmt.Sprintf("Room %s not found", req.RoomID), - } - return nil + return nil, eventutil.ErrRoomNoExists{} } memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err), - } - return nil + return nil, err } - memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, memberNIDs) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.Events: %s", err), - } - return nil + return nil, err } inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) - res.Affected = make([]string, 0, len(memberEvents)) + affected = make([]string, 0, len(memberEvents)) latestReq := &api.QueryLatestEventsAndStateRequest{ - RoomID: req.RoomID, + RoomID: roomID, } latestRes := &api.QueryLatestEventsAndStateResponse{} if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err), - } - return nil + return nil, err + } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err } prevEvents := latestRes.LatestEvents + var senderDomain spec.ServerName + var eventsNeeded gomatrixserverlib.StateNeeded + var identity *fclient.SigningIdentity + var event *types.HeaderedEvent for _, memberEvent := range memberEvents { if memberEvent.StateKey() == nil { continue @@ -101,57 +91,42 @@ func (r *Admin) PerformAdminEvacuateRoom( var memberContent gomatrixserverlib.MemberContent if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("json.Unmarshal: %s", err), - } - return nil + return nil, err } - memberContent.Membership = gomatrixserverlib.Leave + memberContent.Membership = spec.Leave stateKey := *memberEvent.StateKey() - fledglingEvent := &gomatrixserverlib.EventBuilder{ - RoomID: req.RoomID, - Type: gomatrixserverlib.MRoomMember, + fledglingEvent := &gomatrixserverlib.ProtoEvent{ + RoomID: roomID, + Type: spec.MRoomMember, StateKey: &stateKey, - Sender: stateKey, + SenderID: stateKey, PrevEvents: prevEvents, } - _, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender) - if err != nil { + userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID)) + if err != nil || userID == nil { continue } + senderDomain = userID.Domain() if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("json.Marshal: %s", err), - } - return nil + return nil, err } - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + eventsNeeded, err = gomatrixserverlib.StateNeededForProtoEvent(fledglingEvent) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), - } - return nil + return nil, err } - identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + identity, err = r.Cfg.Matrix.SigningIdentityFor(senderDomain) if err != nil { continue } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) + event, err = eventutil.BuildEvent(ctx, fledglingEvent, identity, time.Now(), &eventsNeeded, latestRes) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), - } - return nil + return nil, err } inputEvents = append(inputEvents, api.InputRoomEvent{ @@ -160,119 +135,100 @@ func (r *Admin) PerformAdminEvacuateRoom( Origin: senderDomain, SendAsServer: string(senderDomain), }) - res.Affected = append(res.Affected, stateKey) - prevEvents = []gomatrixserverlib.EventReference{ - event.EventReference(), - } + affected = append(affected, stateKey) + prevEvents = []string{event.EventID()} } inputReq := &api.InputRoomEventsRequest{ InputRoomEvents: inputEvents, - Asynchronous: true, + Asynchronous: false, } inputRes := &api.InputRoomEventsResponse{} - return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) + r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) + return affected, nil } +// PerformAdminEvacuateUser will remove the given user from all rooms. func (r *Admin) PerformAdminEvacuateUser( ctx context.Context, - req *api.PerformAdminEvacuateUserRequest, - res *api.PerformAdminEvacuateUserResponse, -) error { - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + userID string, +) (affected []string, err error) { + fullUserID, err := spec.NewUserID(userID, true) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Malformed user ID: %s", err), - } - return nil + return nil, err } - if !r.Cfg.Matrix.IsLocalServerName(domain) { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: "Can only evacuate local users using this endpoint", - } - return nil + if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) { + return nil, fmt.Errorf("can only evacuate local users using this endpoint") } - roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join) - if err != nil && err != sql.ErrNoRows { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), - } - return nil + roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join) + if err != nil { + return nil, err } - inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite) + inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite) if err != nil && err != sql.ErrNoRows { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), - } - return nil + return nil, err } - for _, roomID := range append(roomIDs, inviteRoomIDs...) { + allRooms := append(roomIDs, inviteRoomIDs...) + affected = make([]string, 0, len(allRooms)) + for _, roomID := range allRooms { leaveReq := &api.PerformLeaveRequest{ RoomID: roomID, - UserID: req.UserID, + Leaver: *fullUserID, } leaveRes := &api.PerformLeaveResponse{} outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err), - } - return nil + return nil, err } - res.Affected = append(res.Affected, roomID) + affected = append(affected, roomID) if len(outputEvents) == 0 { continue } if err := r.Inputer.OutputProducer.ProduceRoomEvents(roomID, outputEvents); err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err), - } - return nil + return nil, err } } - return nil + return affected, nil } +// PerformAdminPurgeRoom removes all traces for the given room from the database. func (r *Admin) PerformAdminPurgeRoom( ctx context.Context, - req *api.PerformAdminPurgeRoomRequest, - res *api.PerformAdminPurgeRoomResponse, + roomID string, ) error { // Validate we actually got a room ID and nothing else - if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Malformed room ID: %s", err), - } - return nil + if _, _, err := gomatrixserverlib.SplitID('!', roomID); err != nil { + return err } - logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") - if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { - logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: err.Error(), - } - return nil + // Evacuate the room before purging it from the database + evacAffected, err := r.PerformAdminEvacuateRoom(ctx, roomID) + if err != nil { + logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to evacuate room before purging") + return err + } + + logrus.WithFields(logrus.Fields{ + "room_id": roomID, + "evacuated_users": len(evacAffected), + }).Warn("Evacuated room, purging room from roomserver now") + + logrus.WithField("room_id", roomID).Warn("Purging room from roomserver") + if err := r.DB.PurgeRoom(ctx, roomID); err != nil { + logrus.WithField("room_id", roomID).WithError(err).Warn("Failed to purge room from roomserver") + return err } - logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") + logrus.WithField("room_id", roomID).Warn("Room purged from roomserver, informing other components") - return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{ { Type: api.OutputTypePurgeRoom, PurgeRoom: &api.OutputPurgeRoom{ - RoomID: req.RoomID, + RoomID: roomID, }, }, }) @@ -280,97 +236,85 @@ func (r *Admin) PerformAdminPurgeRoom( func (r *Admin) PerformAdminDownloadState( ctx context.Context, - req *api.PerformAdminDownloadStateRequest, - res *api.PerformAdminDownloadStateResponse, + roomID, userID string, serverName spec.ServerName, ) error { - _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID) + fullUserID, err := spec.NewUserID(userID, true) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err), - } - return nil + return err } + senderDomain := fullUserID.Domain() - roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), - } - return nil + return err } if roomInfo == nil || roomInfo.IsStub() { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("room %q not found", req.RoomID), - } - return nil + return eventutil.ErrRoomNoExists{} } fwdExtremities, _, depth, err := r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.DB.LatestEventIDs: %s", err), - } - return nil + return err } - authEventMap := map[string]*gomatrixserverlib.Event{} - stateEventMap := map[string]*gomatrixserverlib.Event{} + authEventMap := map[string]gomatrixserverlib.PDU{} + stateEventMap := map[string]gomatrixserverlib.PDU{} for _, fwdExtremity := range fwdExtremities { var state gomatrixserverlib.StateResponse - state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) + state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, serverName, roomID, fwdExtremity, roomInfo.RoomVersion) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity.EventID, err), - } - return nil + return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = authEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = stateEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } stateEventMap[stateEvent.EventID()] = stateEvent } } - authEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(authEventMap)) - stateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEventMap)) + authEvents := make([]*types.HeaderedEvent, 0, len(authEventMap)) + stateEvents := make([]*types.HeaderedEvent, 0, len(stateEventMap)) stateIDs := make([]string, 0, len(stateEventMap)) for _, authEvent := range authEventMap { - authEvents = append(authEvents, authEvent.Headered(roomInfo.RoomVersion)) + authEvents = append(authEvents, &types.HeaderedEvent{PDU: authEvent}) } for _, stateEvent := range stateEventMap { - stateEvents = append(stateEvents, stateEvent.Headered(roomInfo.RoomVersion)) + stateEvents = append(stateEvents, &types.HeaderedEvent{PDU: stateEvent}) stateIDs = append(stateIDs, stateEvent.EventID()) } - builder := &gomatrixserverlib.EventBuilder{ - Type: "org.matrix.dendrite.state_download", - Sender: req.UserID, - RoomID: req.RoomID, - Content: gomatrixserverlib.RawJSON("{}"), + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) + if err != nil { + return err + } + proto := &gomatrixserverlib.ProtoEvent{ + Type: "org.matrix.dendrite.state_download", + SenderID: string(senderID), + RoomID: roomID, + Content: spec.RawJSON("{}"), } - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), - } - return nil + return fmt.Errorf("gomatrixserverlib.StateNeededForProtoEvent: %w", err) } queryRes := &api.QueryLatestEventsAndStateResponse{ @@ -386,13 +330,9 @@ func (r *Admin) PerformAdminDownloadState( return err } - ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) + ev, err := eventutil.BuildEvent(ctx, proto, identity, time.Now(), &eventsNeeded, queryRes) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), - } - return nil + return fmt.Errorf("eventutil.BuildEvent: %w", err) } inputReq := &api.InputRoomEventsRequest{ @@ -416,19 +356,10 @@ func (r *Admin) PerformAdminDownloadState( SendAsServer: string(r.Cfg.Matrix.ServerName), }) - if err := r.Inputer.InputRoomEvents(ctx, inputReq, inputRes); err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("r.Inputer.InputRoomEvents: %s", err), - } - return nil - } + r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) if inputRes.ErrMsg != "" { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: inputRes.ErrMsg, - } + return inputRes.Err() } return nil diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 23862b242a..33200e819c 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -37,13 +38,14 @@ import ( const maxBackfillServers = 5 type Backfiller struct { - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool DB storage.Database FSAPI federationAPI.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier + Querier api.QuerySenderIDAPI // The servers which should be preferred above other servers when backfilling - PreferServers []gomatrixserverlib.ServerName + PreferServers []spec.ServerName } // PerformBackfill implements api.RoomServerQueryAPI @@ -78,14 +80,14 @@ func (r *Backfiller) PerformBackfill( } // Scan the event tree for events to send back. - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier) if err != nil { return err } // Retrieve events from the list that was filled previously. If we fail to get // events from the database then attempt once to get them from federation instead. - var loadedEvents []*gomatrixserverlib.Event + var loadedEvents []gomatrixserverlib.PDU loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { if _, ok := err.(types.MissingEventError); ok { @@ -98,7 +100,7 @@ func (r *Backfiller) PerformBackfill( if _, ok := redactEventIDs[event.EventID()]; ok { event.Redact() } - response.Events = append(response.Events, event.Headered(info.RoomVersion)) + response.Events = append(response.Events, &types.HeaderedEvent{PDU: event}) } return err @@ -112,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers, info.RoomVersion) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass @@ -120,7 +122,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) + }, ) // Only return an error if we really couldn't get any events. if err != nil && len(events) == 0 { @@ -132,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events) for _, ev := range backfilledEventMap { // now add state for these events @@ -167,7 +171,10 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. - res.Events = events + res.Events = make([]*types.HeaderedEvent, len(events)) + for i := range events { + res.Events[i] = &types.HeaderedEvent{PDU: events[i]} + } res.HistoryVisibility = requester.historyVisiblity return nil } @@ -175,7 +182,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // best effort. func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) { + backfillRequester *backfillRequester, stateIDs []string, virtualHost spec.ServerName) { servers := backfillRequester.servers @@ -185,7 +192,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") return } - missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event + missingMap := make(map[string]*types.HeaderedEvent) // id -> event for _, id := range stateIDs { if _, ok := nidMap[id]; !ok { missingMap[id] = nil @@ -206,7 +213,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") continue @@ -226,63 +235,68 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom logger.WithError(err).Warn("event failed PDU checks") continue } - missingMap[id] = res.Event + missingMap[id] = &types.HeaderedEvent{PDU: res.Event} } } } - var newEvents []*gomatrixserverlib.HeaderedEvent + var newEvents []gomatrixserverlib.PDU for _, ev := range missingMap { if ev != nil { - newEvents = append(newEvents, ev) + newEvents = append(newEvents, ev.PDU) } } util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) - persistEvents(ctx, r.DB, newEvents) + persistEvents(ctx, r.DB, r.Querier, newEvents) } // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { db storage.Database fsAPI federationAPI.RoomserverFederationAPI - virtualHost gomatrixserverlib.ServerName - isLocalServerName func(gomatrixserverlib.ServerName) bool - preferServer map[gomatrixserverlib.ServerName]bool + querier api.QuerySenderIDAPI + virtualHost spec.ServerName + isLocalServerName func(spec.ServerName) bool + preferServer map[spec.ServerName]bool bwExtrems map[string][]string // per-request state - servers []gomatrixserverlib.ServerName + servers []spec.ServerName eventIDToBeforeStateIDs map[string][]string - eventIDMap map[string]*gomatrixserverlib.Event + eventIDMap map[string]gomatrixserverlib.PDU historyVisiblity gomatrixserverlib.HistoryVisibility - roomInfo types.RoomInfo + roomVersion gomatrixserverlib.RoomVersion } func newBackfillRequester( db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, - virtualHost gomatrixserverlib.ServerName, - isLocalServerName func(gomatrixserverlib.ServerName) bool, - bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, + querier api.QuerySenderIDAPI, + virtualHost spec.ServerName, + isLocalServerName func(spec.ServerName) bool, + bwExtrems map[string][]string, preferServers []spec.ServerName, + roomVersion gomatrixserverlib.RoomVersion, ) *backfillRequester { - preferServer := make(map[gomatrixserverlib.ServerName]bool) + preferServer := make(map[spec.ServerName]bool) for _, p := range preferServers { preferServer[p] = true } return &backfillRequester{ db: db, fsAPI: fsAPI, + querier: querier, virtualHost: virtualHost, isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), - eventIDMap: make(map[string]*gomatrixserverlib.Event), + eventIDMap: make(map[string]gomatrixserverlib.PDU), bwExtrems: bwExtrems, preferServer: preferServer, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, + roomVersion: roomVersion, } } -func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent *gomatrixserverlib.HeaderedEvent) ([]string, error) { - b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap() +func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) { + b.eventIDMap[targetEvent.EventID()] = targetEvent if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { return ids, nil } @@ -304,7 +318,7 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent if !ok { goto FederationHit } - newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs) + newStateIDs := b.calculateNewStateIDs(targetEvent, prevEvent, prevEventStateIDs) if newStateIDs != nil { b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs return newStateIDs, nil @@ -333,7 +347,7 @@ FederationHit: return nil, lastErr } -func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatrixserverlib.Event, prevEventStateIDs []string) []string { +func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.PDU, prevEventStateIDs []string) []string { newStateIDs := prevEventStateIDs[:] if prevEvent.StateKey() == nil { // state is the same as the previous event @@ -371,7 +385,7 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatri } func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - event *gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { + event gomatrixserverlib.PDU, eventIDs []string) (map[string]gomatrixserverlib.PDU, error) { // try to fetch the events from the database first events, err := b.ProvideEvents(roomVer, eventIDs) @@ -381,7 +395,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr } else { logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs)) if len(events) == len(eventIDs) { - result := make(map[string]*gomatrixserverlib.Event) + result := make(map[string]gomatrixserverlib.PDU) for i := range events { result[events[i].EventID()] = events[i] b.eventIDMap[events[i].EventID()] = events[i] @@ -415,7 +429,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr // It returns a list of servers which can be queried for backfill requests. These servers // will be servers that are in the room already. The entries at the beginning are preferred servers // and will be tried first. An empty list will fail the request. -func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { +func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []spec.ServerName { // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use // its successor, so look it up. successor := "" @@ -452,14 +466,14 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -478,19 +492,23 @@ FindSuccessor: memberEvents = append(memberEvents, memberEventsFromVis...) // Store the server names in a temporary map to avoid duplicates. - serverSet := make(map[gomatrixserverlib.ServerName]bool) + serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { - serverSet[senderDomain] = true + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + continue + } + if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil { + serverSet[sender.Domain()] = true } } - var servers []gomatrixserverlib.ServerName + var servers []spec.ServerName for server := range serverSet { if b.isLocalServerName(server) { continue } if b.preferServer[server] { // insert at the front - servers = append([]gomatrixserverlib.ServerName{server}, servers...) + servers = append([]spec.ServerName{server}, servers...) } else { // insert at the back servers = append(servers, server) } @@ -505,14 +523,14 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string, +func (b *backfillRequester) Backfill(ctx context.Context, origin, server spec.ServerName, roomID string, limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs) return tx, err } -func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { +func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.PDU, error) { ctx := context.Background() nidMap, err := b.db.EventNIDs(ctx, eventIDs) if err != nil { @@ -521,22 +539,18 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, } eventNIDs := make([]types.EventNID, len(nidMap)) i := 0 - roomNID := b.roomInfo.RoomNID for _, nid := range nidMap { eventNIDs[i] = nid.EventNID i++ - if roomNID == 0 { - roomNID = nid.RoomNID - } } - eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, b.roomVersion, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err } - events := make([]*gomatrixserverlib.Event, len(eventsWithNids)) + events := make([]gomatrixserverlib.PDU, len(eventsWithNids)) for i := range eventsWithNids { - events[i] = eventsWithNids[i].Event + events[i] = eventsWithNids[i].PDU } return events, nil } @@ -546,8 +560,8 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, - thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { + ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, + thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { @@ -559,19 +573,22 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) + if roomInfo == nil { + return nil, gomatrixserverlib.HistoryVisibilityJoined, types.ErrorInvalidRoomInfo + } + stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. return nil, gomatrixserverlib.HistoryVisibilityJoined, err } - events := make([]*gomatrixserverlib.Event, len(stateEvents)) + events := make([]gomatrixserverlib.PDU, len(stateEvents)) for i := range stateEvents { - events[i] = stateEvents[i].Event + events[i] = stateEvents[i].PDU } // Can we see events in the room? - canSeeEvents := auth.IsServerAllowed(thisServer, true, events) + canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events) visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) @@ -582,11 +599,11 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, roomInfo, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo.RoomVersion, joinEventNIDs) return evs, visibility, err } -func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { +func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) { var roomNID types.RoomNID var eventNID types.EventNID backfilledEventMap := make(map[string]types.Event) @@ -603,7 +620,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs i++ } - roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap()) + roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev) if err != nil { logrus.WithError(err).Error("failed to get or create roomNID") continue @@ -622,15 +639,15 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs continue } - eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false) + eventNID, _, err = db.StoreEvent(ctx, ev, roomInfo, eventTypeNID, eventStateKeyNID, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue } - resolver := state.NewStateResolution(db, roomInfo) + resolver := state.NewStateResolution(db, roomInfo, querier) - _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), &resolver) + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") continue @@ -639,12 +656,12 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs // It's also possible for this event to be a redaction which results in another event being // redacted, which we don't care about since we aren't returning it in this backfill. if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() { - ev = redactedEvent.Headered(ev.RoomVersion) + ev = redactedEvent events[j] = ev } backfilledEventMap[ev.EventID()] = types.Event{ EventNID: eventNID, - Event: ev.Unwrap(), + PDU: ev, } } return roomNID, backfilledEventMap diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go new file mode 100644 index 0000000000..8c96564533 --- /dev/null +++ b/roomserver/internal/perform/perform_create_room.go @@ -0,0 +1,594 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + "crypto/ed25519" + "encoding/json" + "fmt" + "net/http" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +const ( + historyVisibilityShared = "shared" +) + +type Creator struct { + DB storage.Database + Cfg *config.RoomServer + RSAPI api.RoomserverInternalAPI +} + +// PerformCreateRoom handles all the steps necessary to create a new room. +// nolint: gocyclo +func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) { + verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion) + if err != nil { + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("unknown room version"), + } + } + + createContent := map[string]interface{}{} + if len(createRequest.CreationContent) > 0 { + if err = json.Unmarshal(createRequest.CreationContent, &createContent); err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid create content"), + } + } + } + + _, err = c.DB.AssignRoomNID(ctx, roomID, createRequest.RoomVersion) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to assign roomNID") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + var senderID spec.SenderID + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + // create user room key if needed + key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if keyErr != nil { + util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + senderID = spec.SenderIDFromPseudoIDKey(key) + } else { + senderID = spec.SenderID(userID.String()) + } + createContent["creator"] = senderID + createContent["room_version"] = createRequest.RoomVersion + powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID)) + joinRuleContent := gomatrixserverlib.JoinRuleContent{ + JoinRule: spec.Invite, + } + historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{ + HistoryVisibility: historyVisibilityShared, + } + + if createRequest.PowerLevelContentOverride != nil { + // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults + err = json.Unmarshal(createRequest.PowerLevelContentOverride, &powerLevelContent) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("malformed power_level_content_override"), + } + } + } + + var guestsCanJoin bool + switch createRequest.StatePreset { + case spec.PresetPrivateChat: + joinRuleContent.JoinRule = spec.Invite + historyVisibilityContent.HistoryVisibility = historyVisibilityShared + guestsCanJoin = true + case spec.PresetTrustedPrivateChat: + joinRuleContent.JoinRule = spec.Invite + historyVisibilityContent.HistoryVisibility = historyVisibilityShared + for _, invitee := range createRequest.InvitedUsers { + powerLevelContent.Users[invitee] = 100 + } + guestsCanJoin = true + case spec.PresetPublicChat: + joinRuleContent.JoinRule = spec.Public + historyVisibilityContent.HistoryVisibility = historyVisibilityShared + } + + createEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomCreate, + Content: createContent, + } + powerLevelEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomPowerLevels, + Content: powerLevelContent, + } + joinRuleEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomJoinRules, + Content: joinRuleContent, + } + historyVisibilityEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomHistoryVisibility, + Content: historyVisibilityContent, + } + membershipEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomMember, + StateKey: string(senderID), + } + + memberContent := gomatrixserverlib.MemberContent{ + Membership: spec.Join, + DisplayName: createRequest.UserDisplayName, + AvatarURL: createRequest.UserAvatarURL, + } + + // get the signing identity + identity, err := c.Cfg.Matrix.SigningIdentityFor(userID.Domain()) // we MUST use the server signing mxid_mapping + if err != nil { + logrus.WithError(err).WithField("domain", userID.Domain()).Error("unable to find signing identity for domain") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // If we are creating a room with pseudo IDs, create and sign the MXIDMapping + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + var pseudoIDKey ed25519.PrivateKey + pseudoIDKey, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + mapping := &gomatrixserverlib.MXIDMapping{ + UserRoomKey: spec.SenderIDFromPseudoIDKey(pseudoIDKey), + UserID: userID.String(), + } + + // Sign the mapping with the server identity + if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil { + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + memberContent.MXIDMapping = mapping + + // sign all events with the pseudo ID key + identity = &fclient.SigningIdentity{ + ServerName: "self", + KeyID: "ed25519:1", + PrivateKey: pseudoIDKey, + } + } + membershipEvent.Content = memberContent + + var nameEvent *gomatrixserverlib.FledglingEvent + var topicEvent *gomatrixserverlib.FledglingEvent + var guestAccessEvent *gomatrixserverlib.FledglingEvent + var aliasEvent *gomatrixserverlib.FledglingEvent + + if createRequest.RoomName != "" { + nameEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomName, + Content: eventutil.NameContent{ + Name: createRequest.RoomName, + }, + } + } + + if createRequest.Topic != "" { + topicEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomTopic, + Content: eventutil.TopicContent{ + Topic: createRequest.Topic, + }, + } + } + + if guestsCanJoin { + guestAccessEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomGuestAccess, + Content: eventutil.GuestAccessContent{ + GuestAccess: "can_join", + }, + } + } + + var roomAlias string + if createRequest.RoomAliasName != "" { + roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain()) + // check it's free + // TODO: This races but is better than nothing + hasAliasReq := api.GetRoomIDForAliasRequest{ + Alias: roomAlias, + IncludeAppservices: false, + } + + var aliasResp api.GetRoomIDForAliasResponse + err = c.RSAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if aliasResp.RoomID != "" { + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.RoomInUse("Room ID already exists."), + } + } + + aliasEvent = &gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomCanonicalAlias, + Content: eventutil.CanonicalAlias{ + Alias: roomAlias, + }, + } + } + + var initialStateEvents []gomatrixserverlib.FledglingEvent + for i := range createRequest.InitialState { + if createRequest.InitialState[i].StateKey != "" { + initialStateEvents = append(initialStateEvents, createRequest.InitialState[i]) + continue + } + + switch createRequest.InitialState[i].Type { + case spec.MRoomCreate: + continue + + case spec.MRoomPowerLevels: + powerLevelEvent = createRequest.InitialState[i] + + case spec.MRoomJoinRules: + joinRuleEvent = createRequest.InitialState[i] + + case spec.MRoomHistoryVisibility: + historyVisibilityEvent = createRequest.InitialState[i] + + case spec.MRoomGuestAccess: + guestAccessEvent = &createRequest.InitialState[i] + + case spec.MRoomName: + nameEvent = &createRequest.InitialState[i] + + case spec.MRoomTopic: + topicEvent = &createRequest.InitialState[i] + + default: + initialStateEvents = append(initialStateEvents, createRequest.InitialState[i]) + } + } + + // send events into the room in order of: + // 1- m.room.create + // 2- room creator join member + // 3- m.room.power_levels + // 4- m.room.join_rules + // 5- m.room.history_visibility + // 6- m.room.canonical_alias (opt) + // 7- m.room.guest_access (opt) + // 8- other initial state items + // 9- m.room.name (opt) + // 10- m.room.topic (opt) + // 11- invite events (opt) - with is_direct flag if applicable TODO + // 12- 3pid invite events (opt) TODO + // This differs from Synapse slightly. Synapse would vary the ordering of 3-7 + // depending on if those events were in "initial_state" or not. This made it + // harder to reason about, hence sticking to a strict static ordering. + // TODO: Synapse has txn/token ID on each event. Do we need to do this here? + eventsToMake := []gomatrixserverlib.FledglingEvent{ + createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent, + } + if guestAccessEvent != nil { + eventsToMake = append(eventsToMake, *guestAccessEvent) + } + eventsToMake = append(eventsToMake, initialStateEvents...) + if nameEvent != nil { + eventsToMake = append(eventsToMake, *nameEvent) + } + if topicEvent != nil { + eventsToMake = append(eventsToMake, *topicEvent) + } + if aliasEvent != nil { + // TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room. + // This means we might fail creating the alias but say the canonical alias is something that doesn't exist. + eventsToMake = append(eventsToMake, *aliasEvent) + } + + // TODO: invite events + // TODO: 3pid invite events + + var builtEvents []*types.HeaderedEvent + authEvents := gomatrixserverlib.NewAuthEvents(nil) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + for i, e := range eventsToMake { + depth := i + 1 // depth starts at 1 + + builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), + RoomID: roomID.String(), + Type: e.Type, + StateKey: &e.StateKey, + Depth: int64(depth), + }) + err = builder.SetContent(e.Content) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if i > 0 { + builder.PrevEvents = []string{builtEvents[i-1].EventID()} + } + var ev gomatrixserverlib.PDU + if err = builder.AddAuthEvents(&authEvents); err != nil { + util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + ev, err = builder.Build(createRequest.EventTime, identity.ServerName, identity.KeyID, identity.PrivateKey) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("buildEvent failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // Add the event to the list of auth events + builtEvents = append(builtEvents, &types.HeaderedEvent{PDU: ev}) + err = authEvents.AddEvent(ev) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + inputs := make([]api.InputRoomEvent, 0, len(builtEvents)) + for _, event := range builtEvents { + inputs = append(inputs, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: userID.Domain(), + SendAsServer: api.DoNotSendToOtherServers, + }) + } + + // send the events to the roomserver + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // TODO(#269): Reserve room alias while we create the room. This stops us + // from creating the room but still failing due to the alias having already + // been taken. + if roomAlias != "" { + aliasReq := api.SetRoomAliasRequest{ + Alias: roomAlias, + RoomID: roomID.String(), + UserID: userID.String(), + } + + var aliasResp api.SetRoomAliasResponse + err = c.RSAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + if aliasResp.AliasExists { + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.RoomInUse("Room alias already exists."), + } + } + } + + // If this is a direct message then we should invite the participants. + if len(createRequest.InvitedUsers) > 0 { + // Build some stripped state for the invite. + var globalStrippedState []gomatrixserverlib.InviteStrippedState + for _, event := range builtEvents { + // Chosen events from the spec: + // https://spec.matrix.org/v1.3/client-server-api/#stripped-state + switch event.Type() { + case spec.MRoomCreate: + fallthrough + case spec.MRoomName: + fallthrough + case spec.MRoomAvatar: + fallthrough + case spec.MRoomTopic: + fallthrough + case spec.MRoomCanonicalAlias: + fallthrough + case spec.MRoomEncryption: + fallthrough + case spec.MRoomMember: + fallthrough + case spec.MRoomJoinRules: + ev := event.PDU + globalStrippedState = append( + globalStrippedState, + gomatrixserverlib.NewInviteStrippedState(ev), + ) + } + } + + // Process the invites. + var inviteEvent *types.HeaderedEvent + for _, invitee := range createRequest.InvitedUsers { + inviteeUserID, userIDErr := spec.NewUserID(invitee, true) + if userIDErr != nil { + util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID) + if queryErr != nil { + util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + inviteeString := string(inviteeSenderID) + proto := gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), + RoomID: roomID.String(), + Type: "m.room.member", + StateKey: &inviteeString, + } + + content := gomatrixserverlib.MemberContent{ + Membership: spec.Invite, + DisplayName: createRequest.UserDisplayName, + AvatarURL: createRequest.UserAvatarURL, + Reason: "", + IsDirect: createRequest.IsDirect, + } + + if err = proto.SetContent(content); err != nil { + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // Build the invite event. + inviteEvent, err = eventutil.QueryAndBuildEvent(ctx, &proto, identity, createRequest.EventTime, c.RSAPI, nil) + + if err != nil { + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") + continue + } + inviteStrippedState := append( + globalStrippedState, + gomatrixserverlib.NewInviteStrippedState(inviteEvent.PDU), + ) + // Send the invite event to the roomserver. + event := inviteEvent + err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ + Event: event, + InviteRoomState: inviteStrippedState, + RoomVersion: event.Version(), + SendAsServer: string(userID.Domain()), + }) + switch e := err.(type) { + case api.ErrInvalidID: + return "", &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown(e.Error()), + } + case api.ErrNotAllowed: + return "", &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden(e.Error()), + } + case nil: + default: + util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") + sentry.CaptureException(err) + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + } + + if createRequest.Visibility == spec.Public { + // expose this room in the published room list + if err = c.RSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + RoomID: roomID.String(), + Visibility: spec.Public, + }); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to publish room") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + // TODO: visibility/presets/raw initial state + // TODO: Create room alias association + // Make sure this doesn't fall into an application service's namespace though! + + return roomAlias, nil +} diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 1fb6eb43a4..7fbec37107 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -56,7 +56,7 @@ func (r *InboundPeeker) PerformInboundPeek( response.RoomExists = true response.RoomVersion = info.RoomVersion - var stateEvents []*gomatrixserverlib.Event + var stateEvents []gomatrixserverlib.PDU var currentStateSnapshotNID types.StateSnapshotNID latestEventRefs, currentStateSnapshotNID, _, err := @@ -64,22 +64,22 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0]}) if err != nil { return err } - var sortedLatestEvents []*gomatrixserverlib.Event + var sortedLatestEvents []gomatrixserverlib.PDU for _, ev := range latestEvents { - sortedLatestEvents = append(sortedLatestEvents, ev.Event) + sortedLatestEvents = append(sortedLatestEvents, ev.PDU) } sortedLatestEvents = gomatrixserverlib.ReverseTopologicalOrdering( sortedLatestEvents, gomatrixserverlib.TopologicalOrderByPrevEvents, ) - response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion) + response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]} // XXX: do we actually need to do a state resolution here? - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer) var stateEntries []types.StateEntry stateEntries, err = roomState.LoadStateAtSnapshot( @@ -106,11 +106,11 @@ func (r *InboundPeeker) PerformInboundPeek( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) + response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event}) } for _, event := range authEvents { - response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) + response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event}) } err = r.Inputer.OutputProducer.ProduceRoomEvents(request.RoomID, []api.OutputEvent{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 140ed7c8a6..f19a508a3a 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -28,205 +28,189 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - log "github.com/sirupsen/logrus" ) +type QueryState struct { + storage.Database + querier api.QuerySenderIDAPI +} + +func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { + return helpers.GetAuthEvents(ctx, q.Database, event.Version(), event, event.AuthEventIDs()) +} + +func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []gomatrixserverlib.StateKeyTuple) ([]gomatrixserverlib.PDU, error) { + info, err := q.Database.RoomInfo(ctx, roomID.String()) + if err != nil { + return nil, fmt.Errorf("failed to load RoomInfo: %w", err) + } + if info != nil { + roomState := state.NewStateResolution(q.Database, info, q.querier) + stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( + ctx, info.StateSnapshotNID(), stateWanted, + ) + if err != nil { + return nil, nil + } + stateNIDs := []types.EventNID{} + for _, stateNID := range stateEntries { + stateNIDs = append(stateNIDs, stateNID.EventNID) + } + stateEvents, err := q.Database.Events(ctx, info.RoomVersion, stateNIDs) + if err != nil { + return nil, fmt.Errorf("failed to obtain required events: %w", err) + } + + events := []gomatrixserverlib.PDU{} + for _, event := range stateEvents { + events = append(events, event.PDU) + } + return events, nil + } + + return nil, nil +} + type Inviter struct { DB storage.Database Cfg *config.RoomServer FSAPI federationAPI.RoomserverFederationAPI + RSAPI api.RoomserverInternalAPI Inputer *input.Inputer } -// nolint:gocyclo -func (r *Inviter) PerformInvite( - ctx context.Context, - req *api.PerformInviteRequest, - res *api.PerformInviteResponse, +func (r *Inviter) IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error) { + info, err := r.DB.RoomInfo(ctx, roomID.String()) + if err != nil { + return false, fmt.Errorf("failed to load RoomInfo: %w", err) + } + return (info != nil && !info.IsStub()), nil +} + +func (r *Inviter) StateQuerier() gomatrixserverlib.StateQuerier { + return &QueryState{Database: r.DB} +} + +func (r *Inviter) ProcessInviteMembership( + ctx context.Context, inviteEvent *types.HeaderedEvent, ) ([]api.OutputEvent, error) { var outputUpdates []api.OutputEvent - event := req.Event - if event.StateKey() == nil { - return nil, fmt.Errorf("invite must be a state event") - } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + var updater *shared.MembershipUpdater + + validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) if err != nil { - return nil, fmt.Errorf("sender %q is invalid", event.Sender()) + return nil, err } - - roomID := event.RoomID() - targetUserID := *event.StateKey() - info, err := r.DB.RoomInfo(ctx, roomID) + userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) if err != nil { - return nil, fmt.Errorf("failed to load RoomInfo: %w", err) + return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } - - _, domain, err := gomatrixserverlib.SplitID('@', targetUserID) + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) + if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { + return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) + } + outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ + EventNID: 0, + PDU: inviteEvent.PDU, + }, outputUpdates, inviteEvent.Version()) if err != nil { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("The user ID %q is invalid!", targetUserID), - } - return nil, nil + return nil, fmt.Errorf("updateToInviteMembership: %w", err) } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) - isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) - if !isOriginLocal && !isTargetLocal { - res.Error = &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: "The invite must be either from or to a local user", - } - return nil, nil + if err = updater.Commit(); err != nil { + return nil, fmt.Errorf("updater.Commit: %w", err) } + return outputUpdates, nil +} - logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ - "inviter": event.Sender(), - "invitee": *event.StateKey(), - "room_id": roomID, - "event_id": event.EventID(), - }) - logger.WithFields(log.Fields{ - "room_version": req.RoomVersion, - "room_info_exists": info != nil, - "target_local": isTargetLocal, - "origin_local": isOriginLocal, - }).Debug("processing invite event") +// nolint:gocyclo +func (r *Inviter) PerformInvite( + ctx context.Context, + req *api.PerformInviteRequest, +) error { + event := req.Event - inviteState := req.InviteRoomState - if len(inviteState) == 0 && info != nil { - var is []gomatrixserverlib.InviteV2StrippedState - if is, err = buildInviteStrippedState(ctx, r.DB, info, req); err == nil { - inviteState = is - } - } - if len(inviteState) == 0 { - if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil { - return nil, fmt.Errorf("event.SetUnsignedField: %w", err) - } - } else { - if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil { - return nil, fmt.Errorf("event.SetUnsignedField: %w", err) - } + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err } - updateMembershipTableManually := func() ([]api.OutputEvent, error) { - var updater *shared.MembershipUpdater - if updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion); err != nil { - return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) - } - outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ - EventNID: 0, - Event: event.Unwrap(), - }, outputUpdates, req.Event.RoomVersion) - if err != nil { - return nil, fmt.Errorf("updateToInviteMembership: %w", err) - } - if err = updater.Commit(); err != nil { - return nil, fmt.Errorf("updater.Commit: %w", err) - } - logger.Debugf("updated membership to invite and sending invite OutputEvent") - return outputUpdates, nil + sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err != nil { + return spec.InvalidParam("The sender user ID is invalid") + } + if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} } - if (info == nil || info.IsStub()) && !isOriginLocal && isTargetLocal { - // The invite came in over federation for a room that we don't know about - // yet. We need to handle this a bit differently to most invites because - // we don't know the room state, therefore the roomserver can't process - // an input event. Instead we will update the membership table with the - // new invite and generate an output event. - return updateMembershipTableManually() + if event.StateKey() == nil || *event.StateKey() == "" { + return fmt.Errorf("invite must be a state event") + } + invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) + if err != nil || invitedUser == nil { + return spec.InvalidParam("Could not find the matching senderID for this user") } + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) - var isAlreadyJoined bool - if info != nil { - _, _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) + // If we're inviting a local user, we can generate the needed pseudoID key here. (if needed) + if isTargetLocal { + var roomVersion gomatrixserverlib.RoomVersion + roomVersion, err = r.DB.GetRoomVersion(ctx, event.RoomID()) if err != nil { - return nil, fmt.Errorf("r.DB.GetMembership: %w", err) + return err } - } - if isAlreadyJoined { - // If the user is joined to the room then that takes precedence over this - // invite event. It makes little sense to move a user that is already - // joined to the room into the invite state. - // This could plausibly happen if an invite request raced with a join - // request for a user. For example if a user was invited to a public - // room and they joined the room at the same time as the invite was sent. - // The other way this could plausibly happen is if an invite raced with - // a kick. For example if a user was kicked from a room in error and in - // response someone else in the room re-invited them then it is possible - // for the invite request to race with the leave event so that the - // target receives invite before it learns that it has been kicked. - // There are a few ways this could be plausibly handled in the roomserver. - // 1) Store the invite, but mark it as retired. That will result in the - // permanent rejection of that invite event. So even if the target - // user leaves the room and the invite is retransmitted it will be - // ignored. However a new invite with a new event ID would still be - // accepted. - // 2) Silently discard the invite event. This means that if the event - // was retransmitted at a later date after the target user had left - // the room we would accept the invite. However since we hadn't told - // the sending server that the invite had been discarded it would - // have no reason to attempt to retry. - // 3) Signal the sending server that the user is already joined to the - // room. - // For now we will implement option 2. Since in the abesence of a retry - // mechanism it will be equivalent to option 1, and we don't have a - // signalling mechanism to implement option 3. - res.Error = &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: "User is already joined to room", + + switch roomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID) + if err != nil { + return err + } } - logger.Debugf("user already joined") - return nil, nil } - // If the invite originated remotely then we can't send an - // InputRoomEvent for the invite as it will never pass auth checks - // due to lacking room state, but we still need to tell the client - // about the invite so we can accept it, hence we return an output - // event to send to the Sync API. - if !isOriginLocal { - return updateMembershipTableManually() + invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser) + if err != nil { + return fmt.Errorf("failed looking up senderID for invited user") + } + + input := gomatrixserverlib.PerformInviteInput{ + RoomID: *validRoomID, + InviteEvent: event.PDU, + InvitedUser: *invitedUser, + InvitedSenderID: invitedSenderID, + IsTargetLocal: isTargetLocal, + StrippedState: req.InviteRoomState, + MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, + StateQuerier: &QueryState{r.DB, r.RSAPI}, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, } - - // The invite originated locally. Therefore we have a responsibility to - // try and see if the user is allowed to make this invite. We can't do - // this for invites coming in over federation - we have to take those on - // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs()) + inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) if err != nil { - logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( - "processInviteEvent.checkAuthEvents failed for event", - ) - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, + switch e := err.(type) { + case spec.MatrixError: + if e.ErrCode == spec.ErrorForbidden { + return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)} + } } - return nil, nil + return err } - // If the invite originated from us and the target isn't local then we - // should try and send the invite over federation first. It might be - // that the remote user doesn't exist, in which case we can give up - // processing here. - if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { - fsReq := &federationAPI.PerformInviteRequest{ - RoomVersion: req.RoomVersion, - Event: event, - InviteRoomState: inviteState, - } - fsRes := &federationAPI.PerformInviteResponse{} - if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") - return nil, nil + // Use the returned event if there was one (due to federation), otherwise + // send the original invite event to the roomserver. + if inviteEvent == nil { + inviteEvent = event + } + + // if we invited a local user, we can also create a user room key, if it doesn't exist yet. + if isTargetLocal && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID) + if err != nil { + return fmt.Errorf("failed to get user room private key: %w", err) } - event = fsRes.Event - logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) } // Send the invite event to the roomserver input stream. This will @@ -238,69 +222,18 @@ func (r *Inviter) PerformInvite( InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: event, - Origin: senderDomain, + Event: &types.HeaderedEvent{PDU: inviteEvent}, + Origin: sender.Domain(), SendAsServer: req.SendAsServer, }, }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { - return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err := inputRes.Err(); err != nil { + util.GetLogger(ctx).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") + return api.ErrNotAllowed{Err: err} } - if err = inputRes.Err(); err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), - Code: api.PerformErrorNotAllowed, - } - logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") - } - - // Don't notify the sync api of this event in the same way as a federated invite so the invitee - // gets the invite, as the roomserver will do this when it processes the m.room.member invite. - return outputUpdates, nil -} -func buildInviteStrippedState( - ctx context.Context, - db storage.Database, - info *types.RoomInfo, - input *api.PerformInviteRequest, -) ([]gomatrixserverlib.InviteV2StrippedState, error) { - stateWanted := []gomatrixserverlib.StateKeyTuple{} - // "If they are set on the room, at least the state for m.room.avatar, m.room.canonical_alias, m.room.join_rules, and m.room.name SHOULD be included." - // https://matrix.org/docs/spec/client_server/r0.6.0#m-room-member - for _, t := range []string{ - gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias, - gomatrixserverlib.MRoomJoinRules, gomatrixserverlib.MRoomAvatar, - gomatrixserverlib.MRoomEncryption, gomatrixserverlib.MRoomCreate, - } { - stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{ - EventType: t, - StateKey: "", - }) - } - roomState := state.NewStateResolution(db, info) - stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( - ctx, info.StateSnapshotNID(), stateWanted, - ) - if err != nil { - return nil, err - } - stateNIDs := []types.EventNID{} - for _, stateNID := range stateEntries { - stateNIDs = append(stateNIDs, stateNID.EventNID) - } - stateEvents, err := db.Events(ctx, info, stateNIDs) - if err != nil { - return nil, err - } - inviteState := []gomatrixserverlib.InviteV2StrippedState{ - gomatrixserverlib.NewInviteV2StrippedState(input.Event.Event), - } - stateEvents = append(stateEvents, types.Event{Event: input.Event.Unwrap()}) - for _, event := range stateEvents { - inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(event.Event)) - } - return inviteState, nil + return nil } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index fc7ba940c2..c14554640a 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "crypto/ed25519" "database/sql" "errors" "fmt" @@ -24,6 +25,9 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -53,50 +57,34 @@ type Joiner struct { func (r *Joiner) PerformJoin( ctx context.Context, req *rsAPI.PerformJoinRequest, - res *rsAPI.PerformJoinResponse, -) error { +) (roomID string, joinedVia spec.ServerName, err error) { logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomIDOrAlias, "user_id": req.UserID, "servers": req.ServerNames, }) logger.Info("User requested to room join") - roomID, joinedVia, err := r.performJoin(context.Background(), req) + roomID, joinedVia, err = r.performJoin(context.Background(), req) if err != nil { logger.WithError(err).Error("Failed to join room") sentry.CaptureException(err) - perr, ok := err.(*rsAPI.PerformError) - if ok { - res.Error = perr - } else { - res.Error = &rsAPI.PerformError{ - Msg: err.Error(), - } - } - return nil + return "", "", err } logger.Info("User joined room successfully") - res.RoomID = roomID - res.JoinedVia = joinedVia - return nil + + return roomID, joinedVia, nil } func (r *Joiner) performJoin( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (string, gomatrixserverlib.ServerName, error) { +) (string, spec.ServerName, error) { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorBadRequest, - Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), - } + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)} } if !r.Cfg.Matrix.IsLocalServerName(domain) { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorBadRequest, - Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), - } + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)} } if strings.HasPrefix(req.RoomIDOrAlias, "!") { return r.performJoinRoomByID(ctx, req) @@ -104,16 +92,13 @@ func (r *Joiner) performJoin( if strings.HasPrefix(req.RoomIDOrAlias, "#") { return r.performJoinRoomByAlias(ctx, req) } - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorBadRequest, - Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), - } + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)} } func (r *Joiner) performJoinRoomByAlias( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (string, gomatrixserverlib.ServerName, error) { +) (string, spec.ServerName, error) { // Get the domain part of the room alias. _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) if err != nil { @@ -163,12 +148,12 @@ func (r *Joiner) performJoinRoomByAlias( return r.performJoinRoomByID(ctx, req) } -// TODO: Break this function up a bit +// TODO: Break this function up a bit & move to GMSL // nolint:gocyclo func (r *Joiner) performJoinRoomByID( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (string, gomatrixserverlib.ServerName, error) { +) (string, spec.ServerName, error) { // The original client request ?server_name=... may include this HS so filter that out so we // don't attempt to make_join with ourselves for i := 0; i < len(req.ServerNames); i++ { @@ -180,55 +165,16 @@ func (r *Joiner) performJoinRoomByID( } // Get the domain part of the room ID. - _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) + roomID, err := spec.NewRoomID(req.RoomIDOrAlias) if err != nil { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorBadRequest, - Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err), - } + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)} } // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if !r.Cfg.Matrix.IsLocalServerName(domain) { - req.ServerNames = append(req.ServerNames, domain) - } - - // Prepare the template for the join event. - userID := req.UserID - _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) - if err != nil { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorBadRequest, - Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err), - } - } - eb := gomatrixserverlib.EventBuilder{ - Type: gomatrixserverlib.MRoomMember, - Sender: userID, - StateKey: &userID, - RoomID: req.RoomIDOrAlias, - Redacts: "", - } - if err = eb.SetUnsigned(struct{}{}); err != nil { - return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) - } - - // It is possible for the request to include some "content" for the - // event. We'll always overwrite the "membership" key, but the rest, - // like "display_name" or "avatar_url", will be kept if supplied. - if req.Content == nil { - req.Content = map[string]interface{}{} - } - req.Content["membership"] = gomatrixserverlib.Join - if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil { - return "", "", aerr - } else if authorisedVia != "" { - req.Content["join_authorised_via_users_server"] = authorisedVia - } - if err = eb.SetContent(req.Content); err != nil { - return "", "", fmt.Errorf("eb.SetContent: %w", err) + if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { + req.ServerNames = append(req.ServerNames, roomID.Domain()) } // Force a federated join if we aren't in the room and we've been @@ -243,29 +189,64 @@ func (r *Joiner) performJoinRoomByID( serverInRoom := inRoomRes.IsInRoom forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom + userID, err := spec.NewUserID(req.UserID, true) + if err != nil { + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} + } + + // Look up the room NID for the supplied room ID. + var senderID spec.SenderID + checkInvitePending := false + info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias) + if err == nil && info != nil { + switch info.RoomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) + if err == nil { + checkInvitePending = true + } + if senderID == "" { + // create user room key if needed + key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if keyErr != nil { + util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr) + } + senderID = spec.SenderIDFromPseudoIDKey(key) + } + default: + checkInvitePending = true + senderID = spec.SenderID(userID.String()) + } + } + + userDomain := userID.Domain() + // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) - if err == nil && !serverInRoom && isInvitePending { - _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) - if ierr != nil { - return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } + if checkInvitePending { + isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) + if inviteErr == nil && !serverInRoom && isInvitePending { + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) + } - // If we were invited by someone from another server then we can - // assume they are in the room so we can join via them. - if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { - req.ServerNames = append(req.ServerNames, inviterDomain) - forceFederatedJoin = true - memberEvent := gjson.Parse(string(inviteEvent.JSON())) - // only set unsigned if we've got a content.membership, which we _should_ - if memberEvent.Get("content.membership").Exists() { - req.Unsigned = map[string]interface{}{ - "prev_sender": memberEvent.Get("sender").Str, - "prev_content": map[string]interface{}{ - "is_direct": memberEvent.Get("content.is_direct").Bool(), - "membership": memberEvent.Get("content.membership").Str, - }, + // If we were invited by someone from another server then we can + // assume they are in the room so we can join via them. + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) + forceFederatedJoin = true + memberEvent := gjson.Parse(string(inviteEvent.JSON())) + // only set unsigned if we've got a content.membership, which we _should_ + if memberEvent.Get("content.membership").Exists() { + req.Unsigned = map[string]interface{}{ + "prev_sender": memberEvent.Get("sender").Str, + "prev_content": map[string]interface{}{ + "is_direct": memberEvent.Get("content.is_direct").Bool(), + "membership": memberEvent.Get("content.membership").Str, + }, + } } } } @@ -273,9 +254,9 @@ func (r *Joiner) performJoinRoomByID( // If a guest is trying to join a room, check that the room has a m.room.guest_access event if req.IsGuest { - var guestAccessEvent *gomatrixserverlib.HeaderedEvent + var guestAccessEvent *types.HeaderedEvent guestAccess := "forbidden" - guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "") if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") } @@ -286,16 +267,14 @@ func (r *Joiner) performJoinRoomByID( // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event // is present on the room and has the guest_access value can_join. if guestAccess != "can_join" { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorNotAllowed, - Msg: "Guest access is forbidden", - } + return "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")} } } // If we should do a forced federated join then do that. - var joinedVia gomatrixserverlib.ServerName + var joinedVia spec.ServerName if forceFederatedJoin { + // TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) return req.RoomIDOrAlias, joinedVia, err } @@ -305,21 +284,81 @@ func (r *Joiner) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb) - switch err { + var buildRes rsAPI.QueryLatestEventsAndStateResponse + identity, err := r.RSAPI.SigningIdentityFor(ctx, *roomID, *userID) + if err != nil { + return "", "", fmt.Errorf("error joining local room: %q", err) + } + + // at this point we know we have an existing room + if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + var pseudoIDKey ed25519.PrivateKey + pseudoIDKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", err + } + + mapping := &gomatrixserverlib.MXIDMapping{ + UserRoomKey: spec.SenderIDFromPseudoIDKey(pseudoIDKey), + UserID: userID.String(), + } + + // Sign the mapping with the server identity + if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil { + return "", "", err + } + req.Content["mxid_mapping"] = mapping + + // sign the event with the pseudo ID key + identity = fclient.SigningIdentity{ + ServerName: "self", + KeyID: "ed25519:1", + PrivateKey: pseudoIDKey, + } + } + + senderIDString := string(senderID) + + // Prepare the template for the join event. + proto := gomatrixserverlib.ProtoEvent{ + Type: spec.MRoomMember, + SenderID: senderIDString, + StateKey: &senderIDString, + RoomID: req.RoomIDOrAlias, + Redacts: "", + } + if err = proto.SetUnsigned(struct{}{}); err != nil { + return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) + } + + // It is possible for the request to include some "content" for the + // event. We'll always overwrite the "membership" key, but the rest, + // like "display_name" or "avatar_url", will be kept if supplied. + if req.Content == nil { + req.Content = map[string]interface{}{} + } + req.Content["membership"] = spec.Join + if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { + return "", "", aerr + } else if authorisedVia != "" { + req.Content["join_authorised_via_users_server"] = authorisedVia + } + if err = proto.SetContent(req.Content); err != nil { + return "", "", fmt.Errorf("eb.SetContent: %w", err) + } + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes) + + switch err.(type) { case nil: // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't // fail if we can't find the existing membership) because there // is really no harm in just sending another membership event. - membershipReq := &api.QueryMembershipForUserRequest{ - RoomID: req.RoomIDOrAlias, - UserID: userID, - } membershipRes := &api.QueryMembershipForUserResponse{} - _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes) + _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes) // If we haven't already joined the room then send an event // into the room changing our membership status. @@ -328,23 +367,15 @@ func (r *Joiner) performJoinRoomByID( InputRoomEvents: []rsAPI.InputRoomEvent{ { Kind: rsAPI.KindNew, - Event: event.Headered(buildRes.RoomVersion), + Event: event, SendAsServer: string(userDomain), }, }, } inputRes := rsAPI.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorNoOperation, - Msg: fmt.Sprintf("InputRoomEvents failed: %s", err), - } - } + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorNotAllowed, - Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), - } + return "", "", rsAPI.ErrNotAllowed{Err: err} } } @@ -352,15 +383,12 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if r.Cfg.Matrix.IsLocalServerName(domain) { + if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. if len(req.ServerNames) == 0 { - return "", "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorNoRoom, - Msg: fmt.Sprintf("room ID %q does not exist", req.RoomIDOrAlias), - } + return "", "", eventutil.ErrRoomNoExists{} } } @@ -383,7 +411,7 @@ func (r *Joiner) performJoinRoomByID( func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *rsAPI.PerformJoinRequest, -) (gomatrixserverlib.ServerName, error) { +) (spec.ServerName, error) { // Try joining by all of the supplied server names. fedReq := fsAPI.PerformJoinRequest{ RoomID: req.RoomIDOrAlias, // the room ID to try and join @@ -395,11 +423,7 @@ func (r *Joiner) performFederatedJoinRoomByID( fedRes := fsAPI.PerformJoinResponse{} r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) if fedRes.LastError != nil { - return "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrRemote, - Msg: fedRes.LastError.Message, - RemoteCode: fedRes.LastError.Code, - } + return "", fedRes.LastError } return fedRes.JoinedVia, nil } @@ -407,69 +431,12 @@ func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( ctx context.Context, joinReq *rsAPI.PerformJoinRequest, + senderID spec.SenderID, ) (string, error) { - req := &api.QueryRestrictedJoinAllowedRequest{ - UserID: joinReq.UserID, - RoomID: joinReq.RoomIDOrAlias, - } - res := &api.QueryRestrictedJoinAllowedResponse{} - if err := r.Queryer.QueryRestrictedJoinAllowed(ctx, req, res); err != nil { - return "", fmt.Errorf("r.Queryer.QueryRestrictedJoinAllowed: %w", err) - } - if !res.Restricted { - return "", nil - } - if !res.Resident { - return "", nil - } - if !res.Allowed { - return "", &rsAPI.PerformError{ - Code: rsAPI.PerformErrorNotAllowed, - Msg: fmt.Sprintf("The join to room %s was not allowed.", joinReq.RoomIDOrAlias), - } - } - return res.AuthorisedVia, nil -} - -func buildEvent( - ctx context.Context, db storage.Database, cfg *config.Global, - senderDomain gomatrixserverlib.ServerName, - builder *gomatrixserverlib.EventBuilder, -) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) if err != nil { - return nil, nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) - } - - if len(eventsNeeded.Tuples()) == 0 { - return nil, nil, errors.New("expecting state tuples for event builder, got none") + return "", err } - var queryRes rsAPI.QueryLatestEventsAndStateResponse - err = helpers.QueryLatestEventsAndState(ctx, db, &rsAPI.QueryLatestEventsAndStateRequest{ - RoomID: builder.RoomID, - StateToFetch: eventsNeeded.Tuples(), - }, &queryRes) - if err != nil { - switch err.(type) { - case types.MissingStateError: - // We know something about the room but the state seems to be - // insufficient to actually build a new event, so in effect we - // had might as well treat the room as if it doesn't exist. - return nil, nil, eventutil.ErrRoomNoExists - default: - return nil, nil, fmt.Errorf("QueryLatestEventsAndState: %w", err) - } - } - - identity, err := cfg.SigningIdentityFor(senderDomain) - if err != nil { - return nil, nil, err - } - - ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes) - if err != nil { - return nil, nil, err - } - return ev, &queryRes, nil + return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID) } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 86f1dfaee7..a20896cf74 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -19,15 +19,18 @@ import ( "encoding/json" "fmt" "strings" + "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" @@ -39,6 +42,7 @@ type Leaver struct { Cfg *config.RoomServer DB storage.Database FSAPI fsAPI.RoomserverFederationAPI + RSAPI rsAPI.RoomserverInternalAPI UserAPI userapi.RoomserverUserAPI Inputer *input.Inputer } @@ -49,16 +53,12 @@ func (r *Leaver) PerformLeave( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) - } - if !r.Cfg.Matrix.IsLocalServerName(domain) { - return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) + if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) { + return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String()) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomID, - "user_id": req.UserID, + "user_id": req.Leaver.String(), }) logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { @@ -78,21 +78,30 @@ func (r *Leaver) performLeaveRoomByID( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam ) ([]api.OutputEvent, error) { + roomID, err := spec.NewRoomID(req.RoomID) + if err != nil { + return nil, err + } + leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver) + if err != nil { + return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) + } + // If there's an invite outstanding for the room then respond to // that. - isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) + isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) if err == nil && isInvitePending { - _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", senderUser) + sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser) + if serr != nil || sender == nil { + return nil, fmt.Errorf("sender %q has no matching userID", senderUser) } - if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { - return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) + if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver) } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ - UserID: req.UserID, + UserID: req.Leaver.String(), RoomID: req.RoomID, DataType: "m.tag", }, accData); err != nil { @@ -110,7 +119,7 @@ func (r *Leaver) performLeaveRoomByID( // mimic the returned values from Synapse res.Message = "You cannot reject this invite" res.Code = 403 - return nil, jsonerror.LeaveServerNoticeError() + return nil, spec.LeaveServerNoticeError() } } } @@ -122,13 +131,13 @@ func (r *Leaver) performLeaveRoomByID( RoomID: req.RoomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, - StateKey: req.UserID, + EventType: spec.MRoomMember, + StateKey: string(leaver), }, }, } latestRes := api.QueryLatestEventsAndStateResponse{} - if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil { return nil, err } if !latestRes.RoomExists { @@ -137,45 +146,50 @@ func (r *Leaver) performLeaveRoomByID( // Now let's see if the user is in the room. if len(latestRes.StateEvents) == 0 { - return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID) + return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID) } membership, err := latestRes.StateEvents[0].Membership() if err != nil { return nil, fmt.Errorf("error getting membership: %w", err) } - if membership != gomatrixserverlib.Join && membership != gomatrixserverlib.Invite { - return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership) + if membership != spec.Join && membership != spec.Invite { + return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership) } // Prepare the template for the leave event. - userID := req.UserID - eb := gomatrixserverlib.EventBuilder{ - Type: gomatrixserverlib.MRoomMember, - Sender: userID, - StateKey: &userID, + senderIDString := string(leaver) + proto := gomatrixserverlib.ProtoEvent{ + Type: spec.MRoomMember, + SenderID: senderIDString, + StateKey: &senderIDString, RoomID: req.RoomID, Redacts: "", } - if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { + if err = proto.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { return nil, fmt.Errorf("eb.SetContent: %w", err) } - if err = eb.SetUnsigned(struct{}{}); err != nil { + if err = proto.SetUnsigned(struct{}{}); err != nil { return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } - // Get the sender domain. - _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', eb.Sender) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", eb.Sender) - } - // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb) + + validRoomID, err := spec.NewRoomID(req.RoomID) if err != nil { - return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) + return nil, err + } + + var buildRes rsAPI.QueryLatestEventsAndStateResponse + identity, err := r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver) + if err != nil { + return nil, fmt.Errorf("SigningIdentityFor: %w", err) + } + event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes) + if err != nil { + return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err) } // Give our leave event to the roomserver input stream. The @@ -185,16 +199,14 @@ func (r *Leaver) performLeaveRoomByID( InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: event.Headered(buildRes.RoomVersion), - Origin: senderDomain, - SendAsServer: string(senderDomain), + Event: event, + Origin: req.Leaver.Domain(), + SendAsServer: string(req.Leaver.Domain()), }, }, } inputRes := api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { - return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) - } + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } @@ -206,21 +218,17 @@ func (r *Leaver) performFederatedRejectInvite( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam - senderUser, eventID string, + inviteSender spec.UserID, eventID string, + leaver spec.SenderID, ) ([]api.OutputEvent, error) { - _, domain, err := gomatrixserverlib.SplitID('@', senderUser) - if err != nil { - return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err) - } - // Ask the federation sender to perform a federated leave for us. leaveReq := fsAPI.PerformLeaveRequest{ RoomID: req.RoomID, - UserID: req.UserID, - ServerNames: []gomatrixserverlib.ServerName{domain}, + UserID: req.Leaver.String(), + ServerNames: []spec.ServerName{inviteSender.Domain()}, } leaveRes := fsAPI.PerformLeaveResponse{} - if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { // failures in PerformLeave should NEVER stop us from telling other components like the // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") @@ -231,7 +239,7 @@ func (r *Leaver) performFederatedRejectInvite( util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event") } - updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion) + updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") } @@ -254,9 +262,10 @@ func (r *Leaver) performFederatedRejectInvite( { Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ - EventID: eventID, - Membership: "leave", - TargetUserID: req.UserID, + EventID: eventID, + RoomID: req.RoomID, + Membership: "leave", + TargetSenderID: leaver, }, }, }, nil diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 8a7a910349..88fa2a431f 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -26,12 +26,13 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) type Peeker struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI DB storage.Database @@ -43,21 +44,8 @@ type Peeker struct { func (r *Peeker) PerformPeek( ctx context.Context, req *api.PerformPeekRequest, - res *api.PerformPeekResponse, -) error { - roomID, err := r.performPeek(ctx, req) - if err != nil { - perr, ok := err.(*api.PerformError) - if ok { - res.Error = perr - } else { - res.Error = &api.PerformError{ - Msg: err.Error(), - } - } - } - res.RoomID = roomID - return nil +) (roomID string, err error) { + return r.performPeek(ctx, req) } func (r *Peeker) performPeek( @@ -67,16 +55,10 @@ func (r *Peeker) performPeek( // FIXME: there's way too much duplication with performJoin _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { - return "", &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), - } + return "", api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)} } if !r.Cfg.Matrix.IsLocalServerName(domain) { - return "", &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), - } + return "", api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)} } if strings.HasPrefix(req.RoomIDOrAlias, "!") { return r.performPeekRoomByID(ctx, req) @@ -84,10 +66,7 @@ func (r *Peeker) performPeek( if strings.HasPrefix(req.RoomIDOrAlias, "#") { return r.performPeekRoomByAlias(ctx, req) } - return "", &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), - } + return "", api.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)} } func (r *Peeker) performPeekRoomByAlias( @@ -97,7 +76,7 @@ func (r *Peeker) performPeekRoomByAlias( // Get the domain part of the room alias. _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) if err != nil { - return "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias) + return "", api.ErrInvalidID{Err: fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)} } req.ServerNames = append(req.ServerNames, domain) @@ -146,10 +125,7 @@ func (r *Peeker) performPeekRoomByID( // Get the domain part of the room ID. _, domain, err := gomatrixserverlib.SplitID('!', roomID) if err != nil { - return "", &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Room ID %q is invalid: %s", roomID, err), - } + return "", api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)} } // handle federated peeks @@ -168,11 +144,7 @@ func (r *Peeker) performPeekRoomByID( fedRes := fsAPI.PerformOutboundPeekResponse{} _ = r.FSAPI.PerformOutboundPeek(ctx, &fedReq, &fedRes) if fedRes.LastError != nil { - return "", &api.PerformError{ - Code: api.PerformErrRemote, - Msg: fedRes.LastError.Message, - RemoteCode: fedRes.LastError.Code, - } + return "", fedRes.LastError } } @@ -193,17 +165,11 @@ func (r *Peeker) performPeekRoomByID( } if !worldReadable { - return "", &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: "Room is not world-readable", - } + return "", api.ErrNotAllowed{Err: fmt.Errorf("room is not world-readable")} } if ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.encryption", ""); ev != nil { - return "", &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: "Cannot peek into an encrypted room", - } + return "", api.ErrNotAllowed{Err: fmt.Errorf("Cannot peek into an encrypted room")} } // TODO: handle federated peeks diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go index fbbfc3219f..297a4a1899 100644 --- a/roomserver/internal/perform/perform_publish.go +++ b/roomserver/internal/perform/perform_publish.go @@ -25,16 +25,10 @@ type Publisher struct { DB storage.Database } +// PerformPublish publishes or unpublishes a room from the room directory. Returns a database error, if any. func (r *Publisher) PerformPublish( ctx context.Context, req *api.PerformPublishRequest, - res *api.PerformPublishResponse, ) error { - err := r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public") - if err != nil { - res.Error = &api.PerformError{ - Msg: err.Error(), - } - } - return nil + return r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public") } diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 4d714be668..1ea8079d48 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -24,93 +24,58 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type Unpeeker struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI Inputer *input.Inputer } -// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. +// PerformUnpeek handles un-peeking matrix rooms, including over federation by talking to the federationapi. func (r *Unpeeker) PerformUnpeek( ctx context.Context, - req *api.PerformUnpeekRequest, - res *api.PerformUnpeekResponse, -) error { - if err := r.performUnpeek(ctx, req); err != nil { - perr, ok := err.(*api.PerformError) - if ok { - res.Error = perr - } else { - res.Error = &api.PerformError{ - Msg: err.Error(), - } - } - } - return nil -} - -func (r *Unpeeker) performUnpeek( - ctx context.Context, - req *api.PerformUnpeekRequest, + roomID, userID, deviceID string, ) error { // FIXME: there's way too much duplication with performJoin - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - return &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), - } + return api.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", userID)} } if !r.Cfg.Matrix.IsLocalServerName(domain) { - return &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), - } + return api.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", userID)} } - if strings.HasPrefix(req.RoomID, "!") { - return r.performUnpeekRoomByID(ctx, req) - } - return &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Room ID %q is invalid", req.RoomID), + if strings.HasPrefix(roomID, "!") { + return r.performUnpeekRoomByID(ctx, roomID, userID, deviceID) } + return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid", roomID)} } func (r *Unpeeker) performUnpeekRoomByID( _ context.Context, - req *api.PerformUnpeekRequest, + roomID, userID, deviceID string, ) (err error) { // Get the domain part of the room ID. - _, _, err = gomatrixserverlib.SplitID('!', req.RoomID) + _, _, err = gomatrixserverlib.SplitID('!', roomID) if err != nil { - return &api.PerformError{ - Code: api.PerformErrorBadRequest, - Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomID, err), - } + return api.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", roomID, err)} } // TODO: handle federated peeks - - err = r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + // By this point, if req.RoomIDOrAlias contained an alias, then + // it will have been overwritten with a room ID by performPeekRoomByAlias. + // We should now include this in the response so that the CS API can + // return the right room ID. + return r.Inputer.OutputProducer.ProduceRoomEvents(roomID, []api.OutputEvent{ { Type: api.OutputTypeRetirePeek, RetirePeek: &api.OutputRetirePeek{ - RoomID: req.RoomID, - UserID: req.UserID, - DeviceID: req.DeviceID, + RoomID: roomID, + UserID: userID, + DeviceID: deviceID, }, }, }) - if err != nil { - return - } - - // By this point, if req.RoomIDOrAlias contained an alias, then - // it will have been overwritten with a room ID by performPeekRoomByAlias. - // We should now include this in the response so that the CS API can - // return the right room ID. - return nil } diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 02a19911c0..32f547dc14 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -22,8 +22,10 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -33,61 +35,43 @@ type Upgrader struct { URSAPI api.RoomserverInternalAPI } -// fledglingEvent is a helper representation of an event used when creating many events in succession. -type fledglingEvent struct { - Type string `json:"type"` - StateKey string `json:"state_key"` - Content interface{} `json:"content"` -} - // PerformRoomUpgrade upgrades a room from one version to another func (r *Upgrader) PerformRoomUpgrade( ctx context.Context, - req *api.PerformRoomUpgradeRequest, - res *api.PerformRoomUpgradeResponse, -) error { - res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req) - if res.Error != nil { - res.NewRoomID = "" - logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed") - } - return nil + roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion, +) (newRoomID string, err error) { + return r.performRoomUpgrade(ctx, roomID, userID, roomVersion) } func (r *Upgrader) performRoomUpgrade( ctx context.Context, - req *api.PerformRoomUpgradeRequest, -) (string, *api.PerformError) { - roomID := req.RoomID - userID := req.UserID - _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) - if err != nil { - return "", &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: "Error validating the user ID", - } - } + roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion, +) (string, error) { evTime := time.Now() // Return an immediate error if the room does not exist if err := r.validateRoomExists(ctx, roomID); err != nil { - return "", &api.PerformError{ - Code: api.PerformErrorNoRoom, - Msg: "Error validating that the room exists", - } + return "", err + } + + fullRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return "", err + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") + return "", err } // 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone) - if !r.userIsAuthorized(ctx, userID, roomID) { - return "", &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: "You don't have permission to upgrade the room, power level too low.", - } + if !r.userIsAuthorized(ctx, senderID, roomID) { + return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")} } // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) + newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain()) // Get the existing room state for the old room. oldRoomReq := &api.QueryLatestEventsAndStateRequest{ @@ -95,31 +79,29 @@ func (r *Upgrader) performRoomUpgrade( } oldRoomRes := &api.QueryLatestEventsAndStateResponse{} if err := r.URSAPI.QueryLatestEventsAndState(ctx, oldRoomReq, oldRoomRes); err != nil { - return "", &api.PerformError{ - Msg: fmt.Sprintf("Failed to get latest state: %s", err), - } + return "", fmt.Errorf("Failed to get latest state: %s", err) } // Make the tombstone event - tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID) + tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID) if pErr != nil { return "", pErr } // Generate the initial events we need to send into the new room. This includes copied state events and bans // as well as the power level events needed to set up the room - eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, string(req.RoomVersion), tombstoneEvent) + eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent) if pErr != nil { return "", pErr } // Send the setup events to the new room - if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { + if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil { return "", pErr } // 5. Send the tombstone event to the old room - if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { + if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil { return "", pErr } @@ -129,39 +111,32 @@ func (r *Upgrader) performRoomUpgrade( } // If the old room had a canonical alias event, it should be deleted in the old room - if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { + if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil { return "", pErr } // 4. Move local aliases to the new room - if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil { + if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil { return "", pErr } // 6. Restrict power levels in the old room - if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { + if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil { return "", pErr } return newRoomID, nil } -func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, *api.PerformError) { +func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*gomatrixserverlib.PowerLevelContent, error) { oldPowerLevelsEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) - powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() - if err != nil { - util.GetLogger(ctx).WithError(err).Error() - return nil, &api.PerformError{ - Msg: "Power level event was invalid or malformed", - } - } - return powerLevelContent, nil + return oldPowerLevelsEvent.PowerLevels() } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -178,61 +153,53 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel - restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomPowerLevels, StateKey: "", Content: restrictedPowerLevelContent, }) - if resErr != nil { - if resErr.Code == api.PerformErrorNotAllowed { - util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not restrict power levels in old room") - } else { - return resErr - } - } else { - if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { - return resErr - } + + switch resErr.(type) { + case api.ErrNotAllowed: + util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not restrict power levels in old room") + case nil: + return r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers) + default: + return resErr } return nil } func moveLocalAliases(ctx context.Context, - roomID, newRoomID, userID string, - URSAPI api.RoomserverInternalAPI) *api.PerformError { - var err error + roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID, + URSAPI api.RoomserverInternalAPI, +) (err error) { aliasReq := api.GetAliasesForRoomIDRequest{RoomID: roomID} aliasRes := api.GetAliasesForRoomIDResponse{} if err = URSAPI.GetAliasesForRoomID(ctx, &aliasReq, &aliasRes); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to get old room aliases: %s", err), - } + return fmt.Errorf("Failed to get old room aliases: %w", err) } for _, alias := range aliasRes.Aliases { - removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} + removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to remove old room alias: %s", err), - } + return fmt.Errorf("Failed to remove old room alias: %w", err) } - setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} + setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID} setAliasRes := api.SetRoomAliasResponse{} if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to set new room alias: %s", err), - } + return fmt.Errorf("Failed to set new room alias: %w", err) } } return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error { for _, event := range oldRoom.StateEvents { - if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") { + if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue } var aliasContent struct { @@ -240,9 +207,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api AltAliases []string `json:"alt_aliases"` } if err := json.Unmarshal(event.Content(), &aliasContent); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to unmarshal canonical aliases: %s", err), - } + return fmt.Errorf("failed to unmarshal canonical aliases: %w", err) } if aliasContent.Alias == "" && len(aliasContent.AltAliases) == 0 { // There are no canonical aliases to clear, therefore do nothing. @@ -250,34 +215,29 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api } } - emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, fledglingEvent{ - Type: gomatrixserverlib.MRoomCanonicalAlias, + emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomCanonicalAlias, Content: map[string]interface{}{}, }) - if resErr != nil { - if resErr.Code == api.PerformErrorNotAllowed { - util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not set empty canonical alias event in old room") - } else { - return resErr - } - } else { - if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { - return resErr - } + switch resErr.(type) { + case api.ErrNotAllowed: + util.GetLogger(ctx).WithField(logrus.ErrorKey, resErr).Warn("UpgradeRoom: Could not set empty canonical alias event in old room") + case nil: + return r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers) + default: + return resErr } return nil } -func (r *Upgrader) publishIfOldRoomWasPublic(ctx context.Context, roomID, newRoomID string) *api.PerformError { +func (r *Upgrader) publishIfOldRoomWasPublic(ctx context.Context, roomID, newRoomID string) error { // check if the old room was published var pubQueryRes api.QueryPublishedRoomsResponse err := r.URSAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{ RoomID: roomID, }, &pubQueryRes) if err != nil { - return &api.PerformError{ - Msg: "QueryPublishedRooms failed", - } + return err } // if the old room is published (was public), publish the new room @@ -293,46 +253,35 @@ func publishNewRoomAndUnpublishOldRoom( oldRoomID, newRoomID string, ) { // expose this room in the published room list - var pubNewRoomRes api.PerformPublishResponse if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ RoomID: newRoomID, - Visibility: "public", - }, &pubNewRoomRes); err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to reach internal API") - } else if pubNewRoomRes.Error != nil { + Visibility: spec.Public, + }); err != nil { // treat as non-fatal since the room is already made by this point - util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public") + util.GetLogger(ctx).WithError(err).Error("failed to publish room") } - var unpubOldRoomRes api.PerformPublishResponse // remove the old room from the published room list if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ RoomID: oldRoomID, Visibility: "private", - }, &unpubOldRoomRes); err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to reach internal API") - } else if unpubOldRoomRes.Error != nil { + }); err != nil { // treat as non-fatal since the room is already made by this point - util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private") + util.GetLogger(ctx).WithError(err).Error("failed to un-publish room") } } func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error { - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := r.URSAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - return &api.PerformError{ - Code: api.PerformErrorNoRoom, - Msg: "Room does not exist", - } + if _, err := r.URSAPI.QueryRoomVersionForRoom(ctx, roomID); err != nil { + return eventutil.ErrRoomNoExists{} } return nil } -func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, +func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string, ) bool { plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, + EventType: spec.MRoomPowerLevels, StateKey: "", }) if plEvent == nil { @@ -344,19 +293,19 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string, } // Check for power level required to send tombstone event (marks the current room as obsolete), // if not found, use the StateDefault power level - return pl.UserLevel(userID) >= pl.EventLevel("m.room.tombstone", true) + return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true) } // nolint:gocyclo -func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID, newVersion string, tombstoneEvent *gomatrixserverlib.HeaderedEvent) ([]fledglingEvent, *api.PerformError) { - state := make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(oldRoom.StateEvents)) +func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) { + state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents)) for _, event := range oldRoom.StateEvents { if event.StateKey() == nil { // This shouldn't ever happen, but better to be safe than sorry. continue } - if event.Type() == gomatrixserverlib.MRoomMember && !event.StateKeyEquals(userID) { - // With the exception of bans and invites which we do want to copy, we + if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) { + // With the exception of bans which we do want to copy, we // should ignore membership events that aren't our own, as event auth will // prevent us from being able to create membership events on behalf of other // users anyway unless they are invites or bans. @@ -365,52 +314,55 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query continue } switch membership { - case gomatrixserverlib.Ban: - case gomatrixserverlib.Invite: + case spec.Ban: default: continue } } + // skip events that rely on a specific user being present + // TODO: What to do here for pseudoIDs? It's checking non-member events for state keys with userIDs. + sKey := *event.StateKey() + if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" { + continue + } state[gomatrixserverlib.StateKeyTuple{EventType: event.Type(), StateKey: *event.StateKey()}] = event } // The following events are ones that we are going to override manually // in the following section. override := map[gomatrixserverlib.StateKeyTuple]struct{}{ - {EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}: {}, - {EventType: gomatrixserverlib.MRoomMember, StateKey: userID}: {}, - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}: {}, - {EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}: {}, + {EventType: spec.MRoomCreate, StateKey: ""}: {}, + {EventType: spec.MRoomMember, StateKey: string(senderID)}: {}, + {EventType: spec.MRoomPowerLevels, StateKey: ""}: {}, + {EventType: spec.MRoomJoinRules, StateKey: ""}: {}, } // The overridden events are essential events that must be present in the // old room state. Check that they are there. for tuple := range override { if _, ok := state[tuple]; !ok { - return nil, &api.PerformError{ - Msg: fmt.Sprintf("Essential event of type %q state key %q is missing", tuple.EventType, tuple.StateKey), - } + return nil, fmt.Errorf("essential event of type %q state key %q is missing", tuple.EventType, tuple.StateKey) } } - oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}] - oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomMember, StateKey: userID}] - oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}] - oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}] + oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}] + oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}] + oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}] + oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}] // Create the new room create event. Using a map here instead of CreateContent // means that we preserve any other interesting fields that might be present // in the create event (such as for the room types MSC). newCreateContent := map[string]interface{}{} _ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent) - newCreateContent["creator"] = userID + newCreateContent["creator"] = string(senderID) newCreateContent["room_version"] = newVersion newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{ EventID: tombstoneEvent.EventID(), RoomID: roomID, } - newCreateEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomCreate, + newCreateEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomCreate, StateKey: "", Content: newCreateContent, } @@ -421,10 +373,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // the events after it. newMembershipContent := map[string]interface{}{} _ = json.Unmarshal(oldMembershipEvent.Content(), &newMembershipContent) - newMembershipContent["membership"] = gomatrixserverlib.Join - newMembershipEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomMember, - StateKey: userID, + newMembershipContent["membership"] = spec.Join + newMembershipEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomMember, + StateKey: string(senderID), Content: newMembershipContent, } @@ -436,27 +388,26 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query powerLevelContent, err := oldPowerLevelsEvent.PowerLevels() if err != nil { util.GetLogger(ctx).WithError(err).Error() - return nil, &api.PerformError{ - Msg: "Power level event content was invalid", - } + return nil, fmt.Errorf("Power level event content was invalid") } - tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, userID) + + tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID) // Now do the join rules event, same as the create and membership // events. We'll set a sane default of "invite" so that if the // existing join rules contains garbage, the room can still be // upgraded. newJoinRulesContent := map[string]interface{}{ - "join_rule": gomatrixserverlib.Invite, // sane default + "join_rule": spec.Invite, // sane default } _ = json.Unmarshal(oldJoinRulesEvent.Content(), &newJoinRulesContent) - newJoinRulesEvent := fledglingEvent{ - Type: gomatrixserverlib.MRoomJoinRules, + newJoinRulesEvent := gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomJoinRules, StateKey: "", Content: newJoinRulesContent, } - eventsToMake := make([]fledglingEvent, 0, len(state)) + eventsToMake := make([]gomatrixserverlib.FledglingEvent, 0, len(state)) eventsToMake = append( eventsToMake, newCreateEvent, newMembershipEvent, tempPowerLevelsEvent, newJoinRulesEvent, @@ -464,9 +415,9 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // For some reason Sytest expects there to be a guest access event. // Create one if it doesn't exist. - if _, ok := state[gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomGuestAccess, StateKey: ""}]; !ok { - eventsToMake = append(eventsToMake, fledglingEvent{ - Type: gomatrixserverlib.MRoomGuestAccess, + if _, ok := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomGuestAccess, StateKey: ""}]; !ok { + eventsToMake = append(eventsToMake, gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomGuestAccess, Content: map[string]string{ "guest_access": "forbidden", }, @@ -480,7 +431,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // are already in `eventsToMake`. continue } - newEvent := fledglingEvent{ + newEvent := gomatrixserverlib.FledglingEvent{ Type: tuple.EventType, StateKey: tuple.StateKey, } @@ -494,58 +445,64 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query // If we sent a temporary power level event into the room before, // override that now by restoring the original power levels. if powerLevelsOverridden { - eventsToMake = append(eventsToMake, fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + eventsToMake = append(eventsToMake, gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomPowerLevels, Content: powerLevelContent, }) } return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error { var err error - var builtEvents []*gomatrixserverlib.HeaderedEvent + var builtEvents []*types.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) for i, e := range eventsToMake { depth := i + 1 // depth starts at 1 - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, + proto := gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), RoomID: newRoomID, Type: e.Type, StateKey: &e.StateKey, Depth: int64(depth), } - err = builder.SetContent(e.Content) + err = proto.SetContent(e.Content) if err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to set content of new %q event: %s", builder.Type, err), - } + return fmt.Errorf("failed to set content of new %q event: %w", proto.Type, err) } if i > 0 { - builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} + proto.PrevEvents = []string{builtEvents[i-1].EventID()} } - var event *gomatrixserverlib.Event - event, err = r.buildEvent(&builder, userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) + + var verImpl gomatrixserverlib.IRoomVersion + verImpl, err = gomatrixserverlib.GetRoomVersion(newVersion) if err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), - } + return err + } + builder := verImpl.NewEventBuilderFromProtoEvent(&proto) + if err = builder.AddAuthEvents(&authEvents); err != nil { + return err } - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err), - } + var event gomatrixserverlib.PDU + event, err = builder.Build(evTime, userDomain, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey) + if err != nil { + return fmt.Errorf("failed to build new %q event: %w", builder.Type, err) + + } + + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { + return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } // Add the event to the list of auth events - builtEvents = append(builtEvents, event.Headered(gomatrixserverlib.RoomVersion(newVersion))) + builtEvents = append(builtEvents, &types.HeaderedEvent{PDU: event}) err = authEvents.AddEvent(event) if err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to add new %q event to auth set: %s", builder.Type, err), - } + return fmt.Errorf("failed to add new %q event to auth set: %w", builder.Type, err) } } @@ -559,9 +516,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user }) } if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), - } + return fmt.Errorf("failed to send new room %q to roomserver: %w", newRoomID, err) } return nil } @@ -569,87 +524,65 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user func (r *Upgrader) makeTombstoneEvent( ctx context.Context, evTime time.Time, - userID, roomID, newRoomID string, -) (*gomatrixserverlib.HeaderedEvent, *api.PerformError) { + senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string, +) (*types.HeaderedEvent, error) { content := map[string]interface{}{ "body": "This room has been replaced", "replacement_room": newRoomID, } - event := fledglingEvent{ + event := gomatrixserverlib.FledglingEvent{ Type: "m.room.tombstone", Content: content, } - return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event) + return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event) } -func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event fledglingEvent) (*gomatrixserverlib.HeaderedEvent, *api.PerformError) { - builder := gomatrixserverlib.EventBuilder{ - Sender: userID, +func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) { + proto := gomatrixserverlib.ProtoEvent{ + SenderID: string(senderID), RoomID: roomID, Type: event.Type, StateKey: &event.StateKey, } - err := builder.SetContent(event.Content) + err := proto.SetContent(event.Content) if err != nil { - return nil, &api.PerformError{ - Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err), - } + return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err) } // Get the sender domain. - _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender) - if serr != nil { - return nil, &api.PerformError{ - Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err), - } - } identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) if err != nil { - return nil, &api.PerformError{ - Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err), - } + return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err) } var queryRes api.QueryLatestEventsAndStateResponse - headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) - if err == eventutil.ErrRoomNoExists { - return nil, &api.PerformError{ - Code: api.PerformErrorNoRoom, - Msg: "Room does not exist", - } - } else if e, ok := err.(gomatrixserverlib.BadJSONError); ok { - return nil, &api.PerformError{ - Msg: e.Error(), - } - } else if e, ok := err.(gomatrixserverlib.EventValidationError); ok { - if e.Code == gomatrixserverlib.EventValidationTooLarge { - return nil, &api.PerformError{ - Msg: e.Error(), - } - } - return nil, &api.PerformError{ - Msg: e.Error(), - } - } else if err != nil { - return nil, &api.PerformError{ - Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), - } + headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, evTime, r.URSAPI, &queryRes) + switch e := err.(type) { + case nil: + case eventutil.ErrRoomNoExists: + return nil, e + case gomatrixserverlib.BadJSONError: + return nil, e + case gomatrixserverlib.EventValidationError: + return nil, e + default: + return nil, fmt.Errorf("failed to build new %q event: %w", proto.Type, err) } + // check to see if this user can perform this operation - stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) + stateEvents := make([]gomatrixserverlib.PDU, len(queryRes.StateEvents)) for i := range queryRes.StateEvents { - stateEvents[i] = queryRes.StateEvents[i].Event + stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil { - return nil, &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: fmt.Sprintf("Failed to auth new %q event: %s", builder.Type, err), // TODO: Is this error string comprehensible to the client? - } + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { + return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } return headeredEvent, nil } -func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, userID string) (fledglingEvent, bool) { +func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelContent, senderID spec.SenderID) (gomatrixserverlib.FledglingEvent, bool) { // Work out what power level we need in order to be able to send events // of all types into the room. neededPowerLevel := powerLevelContent.StateDefault @@ -674,24 +607,24 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC // If the user who is upgrading the room doesn't already have sufficient // power, then elevate their power levels. - if tempPowerLevelContent.UserLevel(userID) < neededPowerLevel { - tempPowerLevelContent.Users[userID] = neededPowerLevel + if tempPowerLevelContent.UserLevel(senderID) < neededPowerLevel { + tempPowerLevelContent.Users[string(senderID)] = neededPowerLevel powerLevelsOverridden = true } // Then return the temporary power levels event. - return fledglingEvent{ - Type: gomatrixserverlib.MRoomPowerLevels, + return gomatrixserverlib.FledglingEvent{ + Type: spec.MRoomPowerLevels, Content: tempPowerLevelContent, }, powerLevelsOverridden } func (r *Upgrader) sendHeaderedEvent( ctx context.Context, - serverName gomatrixserverlib.ServerName, - headeredEvent *gomatrixserverlib.HeaderedEvent, + serverName spec.ServerName, + headeredEvent *types.HeaderedEvent, sendAsServer string, -) *api.PerformError { +) error { var inputs []api.InputRoomEvent inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, @@ -699,37 +632,5 @@ func (r *Upgrader) sendHeaderedEvent( Origin: serverName, SendAsServer: sendAsServer, }) - if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { - return &api.PerformError{ - Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), - } - } - - return nil -} - -func (r *Upgrader) buildEvent( - builder *gomatrixserverlib.EventBuilder, - serverName gomatrixserverlib.ServerName, - provider gomatrixserverlib.AuthEventProvider, - evTime time.Time, - roomVersion gomatrixserverlib.RoomVersion, -) (*gomatrixserverlib.Event, error) { - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) - if err != nil { - return nil, err - } - refs, err := eventsNeeded.AuthEventReferences(provider) - if err != nil { - return nil, err - } - builder.AuthEvents = refs - event, err := builder.Build( - evTime, serverName, r.Cfg.Matrix.KeyID, - r.Cfg.Matrix.PrivateKey, roomVersion, - ) - if err != nil { - return nil, err - } - return event, nil + return api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index e1b292034d..626d3c13ef 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,12 +16,15 @@ package query import ( "context" + "crypto/ed25519" "database/sql" - "encoding/json" "errors" "fmt" + //"github.com/matrix-org/dendrite/roomserver/internal" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -41,8 +44,44 @@ import ( type Queryer struct { DB storage.Database Cache caching.RoomServerCaches - IsLocalServerName func(gomatrixserverlib.ServerName) bool + IsLocalServerName func(spec.ServerName) bool ServerACLs *acls.ServerACLs + Cfg *config.Dendrite +} + +func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := r.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := api.QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := api.QueryServerJoinedToRoomResponse{} + if err = r.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := r.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -51,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) + return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response) } // QueryStateAfterEvents implements api.RoomserverInternalAPI @@ -68,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -120,14 +159,18 @@ func (r *Queryer) QueryStateAfterEvents( return fmt.Errorf("getAuthChain: %w", err) } - stateEvents, err = gomatrixserverlib.ResolveConflicts(info.RoomVersion, stateEvents, authEvents) + stateEvents, err = gomatrixserverlib.ResolveConflicts( + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, + ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) } } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) + response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event}) } return nil @@ -172,19 +215,20 @@ func (r *Queryer) QueryEventsByID( } for _, event := range events { - response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion)) + response.Events = append(response.Events, &types.HeaderedEvent{PDU: event.PDU}) } return nil } -// QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *Queryer) QueryMembershipForUser( +// QueryMembershipForSenderID implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForSenderID( ctx context.Context, - request *api.QueryMembershipForUserRequest, + roomID spec.RoomID, + senderID spec.SenderID, response *api.QueryMembershipForUserResponse, ) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { return err } @@ -194,16 +238,11 @@ func (r *Queryer) QueryMembershipForUser( } response.RoomExists = true - membershipEventNID, membershipState, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) if err != nil { return err } - if membershipState == tables.MembershipStateInvite { - response.Membership = gomatrixserverlib.Invite - response.IsInRoom = true - } - response.IsRoomForgotten = isRoomforgotten if membershipEventNID == 0 { @@ -214,7 +253,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info.RoomVersion, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -227,6 +266,24 @@ func (r *Queryer) QueryMembershipForUser( return err } +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) + if err != nil { + return err + } + + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) +} + // QueryMembershipAtEvent returns the known memberships at a given event. // If the state before an event is not known, an empty list will be returned // for that event instead. @@ -235,7 +292,7 @@ func (r *Queryer) QueryMembershipAtEvent( request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse, ) error { - response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + response.Membership = make(map[string]*types.HeaderedEvent) info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { @@ -263,8 +320,8 @@ func (r *Queryer) QueryMembershipAtEvent( return err } - response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) + response.Membership = make(map[string]*types.HeaderedEvent) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -310,8 +367,8 @@ func (r *Queryer) QueryMembershipAtEvent( // a given event, overwrite any other existing membership events. for i := range memberships { ev := memberships[i] - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { - response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) + if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { + response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU} } } } @@ -336,7 +393,7 @@ func (r *Queryer) QueryMembershipsForRoom( // If no sender is specified then we will just return the entire // set of memberships for the room, regardless of whether a specific // user is allowed to see them or not. - if request.Sender == "" { + if request.SenderID == "" { var events []types.Event var eventNIDs []types.EventNID eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly) @@ -346,18 +403,20 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, info, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll) + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil } - membershipEventNID, _, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID) if err != nil { return err } @@ -385,9 +444,9 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, info, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) } else { - stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -400,7 +459,9 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll) + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -417,6 +478,9 @@ func (r *Queryer) QueryServerJoinedToRoom( if err != nil { return fmt.Errorf("r.DB.RoomInfo: %w", err) } + if info != nil { + response.RoomVersion = info.RoomVersion + } if info == nil || info.IsStub() { return nil } @@ -440,8 +504,9 @@ func (r *Queryer) QueryServerJoinedToRoom( // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, eventID string, + roomID string, ) (allowed bool, err error) { events, err := r.DB.EventNIDs(ctx, []string{eventID}) if err != nil { @@ -471,7 +536,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( } return helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, eventID, serverName, isInRoom, + ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r, ) } @@ -512,7 +577,7 @@ func (r *Queryer) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r) if err != nil { return err } @@ -522,17 +587,13 @@ func (r *Queryer) QueryMissingEvents( return err } - response.Events = make([]*gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) + response.Events = make([]*types.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) for _, event := range loadedEvents { if !eventsToFilter[event.EventID()] { - roomVersion, verr := r.roomVersion(event.RoomID()) - if verr != nil { - return verr - } if _, ok := redactEventIDs[event.EventID()]; ok { event.Redact() } - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, &types.HeaderedEvent{PDU: event}) } } @@ -559,18 +620,18 @@ func (r *Queryer) QueryStateAndAuthChain( // the entire current state of the room // TODO: this probably means it should be a different query operation... if request.OnlyFetchAuthChain { - var authEvents []*gomatrixserverlib.Event + var authEvents []gomatrixserverlib.PDU authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs) if err != nil { return err } for _, event := range authEvents { - response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) + response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event}) } return nil } - var stateEvents []*gomatrixserverlib.Event + var stateEvents []gomatrixserverlib.PDU stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) if err != nil { return err @@ -593,27 +654,30 @@ func (r *Queryer) QueryStateAndAuthChain( } if request.ResolveState { - if stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, stateEvents, authEvents, - ); err != nil { + stateEvents, err = gomatrixserverlib.ResolveConflicts( + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) + }, + ) + if err != nil { return err } } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) + response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event}) } for _, event := range authEvents { - response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) + response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event}) } return err } // first bool: is rejected, second bool: state missing -func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) { - roomState := state.NewStateResolution(r.DB, roomInfo) +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) { + roomState := state.NewStateResolution(r.DB, roomInfo, r) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -654,13 +718,13 @@ type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Eve // given events. Will *not* error if we don't have all auth events. func GetAuthChain( ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string, -) ([]*gomatrixserverlib.Event, error) { +) ([]gomatrixserverlib.PDU, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new // events that we have learned about and need to find. When `eventsToFetch` // is eventually empty, we should have reached the end of the chain. eventsToFetch := authEventIDs - authEventsMap := make(map[string]*gomatrixserverlib.Event) + authEventsMap := make(map[string]gomatrixserverlib.PDU) for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. @@ -676,14 +740,14 @@ func GetAuthChain( for _, event := range events { // Store the event in the event map - this prevents us from requesting it // from the database again. - authEventsMap[event.EventID()] = event.Event + authEventsMap[event.EventID()] = event.PDU // Extract all of the auth events from the newly obtained event. If we // don't already have a record of the event, record it in the list of // events we want to request for the next pass. - for _, authEvent := range event.AuthEvents() { - if _, ok := authEventsMap[authEvent.EventID]; !ok { - eventsToFetch = append(eventsToFetch, authEvent.EventID) + for _, authEventID := range event.AuthEventIDs() { + if _, ok := authEventsMap[authEventID]; !ok { + eventsToFetch = append(eventsToFetch, authEventID) } } } @@ -691,7 +755,7 @@ func GetAuthChain( // We've now retrieved all of the events we can. Flatten them down into an // array and return them. - var authEvents []*gomatrixserverlib.Event + var authEvents []gomatrixserverlib.PDU for _, event := range authEventsMap { authEvents = append(authEvents, event) } @@ -700,34 +764,20 @@ func GetAuthChain( } // QueryRoomVersionForRoom implements api.RoomserverInternalAPI -func (r *Queryer) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok { - response.RoomVersion = roomVersion - return nil +func (r *Queryer) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) { + if roomVersion, ok := r.Cache.GetRoomVersion(roomID); ok { + return roomVersion, nil } - info, err := r.DB.RoomInfo(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, roomID) if err != nil { - return err + return "", err } if info == nil { - return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID) + return "", fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", roomID) } - response.RoomVersion = info.RoomVersion - r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) - return nil -} - -func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) { - var res api.QueryRoomVersionForRoomResponse - err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{ - RoomID: roomID, - }, &res) - return res.RoomVersion, err + r.Cache.StoreRoomVersion(roomID, info.RoomVersion) + return info.RoomVersion, nil } func (r *Queryer) QueryPublishedRooms( @@ -752,7 +802,7 @@ func (r *Queryer) QueryPublishedRooms( } func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) + res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent) for _, tuple := range req.StateTuples { if tuple.StateKey == "*" && req.AllowWildcards { events, err := r.DB.GetStateEventsWithEventType(ctx, req.RoomID, tuple.EventType) @@ -869,138 +919,120 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq if err != nil { return err } - hchain := make([]*gomatrixserverlib.HeaderedEvent, len(chain)) + hchain := make([]*types.HeaderedEvent, len(chain)) for i := range chain { - hchain[i] = chain[i].Headered(chain[i].Version()) + hchain[i] = &types.HeaderedEvent{PDU: chain[i]} } res.AuthChain = hchain return nil } +func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) { + pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID) + return pending, err +} + +func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) { + return r.DB.RoomInfo(ctx, roomID.String()) +} + +func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { + res, err := r.DB.GetStateEvent(ctx, roomID.String(), eventType, stateKey) + if res == nil { + return nil, err + } + return res, err +} + +func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) { + _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID) + return isIn, err +} + +func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) { + joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true) + if err != nil { + return nil, err + } + + events, err := r.DB.Events(ctx, roomVersion, joinNIDs) + if err != nil { + return nil, err + } + + // For each of the joined users, let's see if we can get a valid + // membership event. + joinedUsers := []gomatrixserverlib.PDU{} + for _, event := range events { + if event.Type() != spec.MRoomMember || event.StateKey() == nil { + continue // shouldn't happen + } + + joinedUsers = append(joinedUsers, event) + } + + return joinedUsers, nil +} + // nolint:gocyclo -func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error { +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { // Look up if we know anything about the room. If it doesn't exist // or is a stub entry then we can't do anything. - roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) } if roomInfo == nil || roomInfo.IsStub() { - return nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) + return "", nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) } - // If the room version doesn't allow restricted joins then don't - // try to process any further. - allowRestrictedJoins, err := roomInfo.RoomVersion.MayAllowRestrictedJoinsInEventAuth() + verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) if err != nil { - return fmt.Errorf("roomInfo.RoomVersion.AllowRestrictedJoinsInEventAuth: %w", err) - } else if !allowRestrictedJoins { - return nil + return "", err } - // Start off by populating the "resident" flag in the response. If we - // come across any rooms in the request that are missing, we will unset - // the flag. - res.Resident = true - // Get the join rules to work out if the join rule is "restricted". - joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomJoinRules, "") + + return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) +} + +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + version, err := r.DB.GetRoomVersion(ctx, roomID.String()) if err != nil { - return fmt.Errorf("r.DB.GetStateEvent: %w", err) + return "", err } - if joinRulesEvent == nil { - return nil - } - var joinRules gomatrixserverlib.JoinRuleContent - if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - // If the join rule isn't "restricted" then there's nothing more to do. - res.Restricted = joinRules.JoinRule == gomatrixserverlib.Restricted - if !res.Restricted { - return nil + + switch version { + case gomatrixserverlib.RoomVersionPseudoIDs: + key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) + if err != nil { + return "", err + } + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + default: + return spec.SenderID(userID.String()), nil } - // If the user is already invited to the room then the join is allowed - // but we don't specify an authorised via user, since the event auth - // will allow the join anyway. - var pending bool - if pending, _, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil { - return fmt.Errorf("helpers.IsInvitePending: %w", err) - } else if pending { - res.Allowed = true - return nil +} + +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + userID, err := spec.NewUserID(string(senderID), true) + if err == nil { + return userID, nil } - // We need to get the power levels content so that we can determine which - // users in the room are entitled to issue invites. We need to use one of - // these users as the authorising user. - powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, gomatrixserverlib.MRoomPowerLevels, "") + + bytes := spec.Base64Bytes{} + err = bytes.Decode(string(senderID)) if err != nil { - return fmt.Errorf("r.DB.GetStateEvent: %w", err) + return nil, err } - var powerLevels gomatrixserverlib.PowerLevelContent - if err = json.Unmarshal(powerLevelsEvent.Content(), &powerLevels); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) + queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}} + result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap) + if err != nil { + return nil, err } - // Step through the join rules and see if the user matches any of them. - for _, rule := range joinRules.Allow { - // We only understand "m.room_membership" rules at this point in - // time, so skip any rule that doesn't match those. - if rule.Type != gomatrixserverlib.MRoomMembership { - continue - } - // See if the room exists. If it doesn't exist or if it's a stub - // room entry then we can't check memberships. - targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID) - if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub() { - res.Resident = false - continue - } - // First of all work out if *we* are still in the room, otherwise - // it's possible that the memberships will be out of date. - isIn, err := r.DB.GetLocalServerInRoom(ctx, targetRoomInfo.RoomNID) - if err != nil || !isIn { - // If we aren't in the room, we can no longer tell if the room - // memberships are up-to-date. - res.Resident = false - continue - } - // At this point we're happy that we are in the room, so now let's - // see if the target user is in the room. - _, _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) - if err != nil { - continue - } - // If the user is not in the room then we will skip them. - if !isIn { - continue - } - // The user is in the room, so now we will need to authorise the - // join using the user ID of one of our own users in the room. Pick - // one. - joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, targetRoomInfo.RoomNID, true, true) - if err != nil || len(joinNIDs) == 0 { - // There should always be more than one join NID at this point - // because we are gated behind GetLocalServerInRoom, but y'know, - // sometimes strange things happen. - continue - } - // For each of the joined users, let's see if we can get a valid - // membership event. - for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID}) - if err != nil || len(events) != 1 { - continue - } - event := events[0] - if event.Type() != gomatrixserverlib.MRoomMember || event.StateKey() == nil { - continue // shouldn't happen - } - // Only users that have the power to invite should be chosen. - if powerLevels.UserLevel(*event.StateKey()) < powerLevels.Invite { - continue - } - res.Resident = true - res.Allowed = true - res.AuthorisedVia = *event.StateKey() - return nil + + if userKeys, ok := result[roomID]; ok { + if userID, ok := userKeys[string(senderID)]; ok { + return spec.NewUserID(userID, true) } } - return nil + + return nil, nil } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 265f326d4c..619d930306 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -18,32 +18,35 @@ import ( "context" "encoding/json" "testing" + "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // used to implement RoomserverInternalAPIEventDB to test getAuthChain type getEventDB struct { - eventMap map[string]*gomatrixserverlib.Event + eventMap map[string]gomatrixserverlib.PDU } func createEventDB() *getEventDB { return &getEventDB{ - eventMap: make(map[string]*gomatrixserverlib.Event), + eventMap: make(map[string]gomatrixserverlib.PDU), } } // Adds a fake event to the storage with given auth events. func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error { - authEvents := []gomatrixserverlib.EventReference{} + authEvents := make([]any, 0, len(authIDs)) for _, authID := range authIDs { - authEvents = append(authEvents, gomatrixserverlib.EventReference{ - EventID: authID, - }) + authEvents = append(authEvents, []any{authID, struct{}{}}) } - builder := map[string]interface{}{ "event_id": eventID, "auth_events": authEvents, @@ -54,8 +57,8 @@ func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error { return err } - event, err := gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON, false, gomatrixserverlib.RoomVersionV1, + event, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON( + eventJSON, false, ) if err != nil { return err @@ -84,7 +87,7 @@ func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInf for _, evID := range eventIDs { res = append(res, types.Event{ EventNID: 0, - Event: db.eventMap[evID], + PDU: db.eventMap[evID], }) } @@ -155,3 +158,30 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs) } } + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.Open(context.Background(), cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return db, close +} + +func TestCurrentEventIsNil(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + querier := Queryer{ + DB: db, + } + + roomID, _ := spec.NewRoomID("!room:server") + event, _ := querier.CurrentStateEvent(context.Background(), *roomID, spec.MRoomMember, "@user:server") + if event != nil { + t.Fatal("Event should equal nil, most likely this is failing because the interface type is not nil, but the value is.") + } + }) +} diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index 9c45219866..165304d49c 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu "adds_state": len(update.NewRoomEvent.AddsStateEventIDs), "removes_state": len(update.NewRoomEvent.RemovesStateEventIDs), "send_as_server": update.NewRoomEvent.SendAsServer, - "sender": update.NewRoomEvent.Event.Sender(), + "sender": update.NewRoomEvent.Event.SenderID(), }) if update.NewRoomEvent.Event.StateKey() != nil { logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey()) @@ -74,7 +74,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu } if eventType == "m.room.server_acl" && update.NewRoomEvent.Event.StateKeyEquals("") { - ev := update.NewRoomEvent.Event.Unwrap() + ev := update.NewRoomEvent.Event.PDU defer r.ACLs.OnServerACLUpdate(ev) } } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 729da15b3f..76b21ad23f 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -8,9 +8,13 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" @@ -31,6 +35,14 @@ import ( "github.com/matrix-org/dendrite/test/testrig" ) +type FakeQuerier struct { + api.QuerySenderIDAPI +} + +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestUsers(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) @@ -61,10 +73,10 @@ func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite and join Bob - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -102,7 +114,7 @@ func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) // Join with the guest user - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -134,8 +146,8 @@ func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI } // revoke guest access - revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + revokeEvent := room.CreateAndInsert(t, alice, spec.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { t.Errorf("failed to send events: %v", err) } @@ -164,10 +176,10 @@ func Test_QueryLeftUsers(t *testing.T) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite and join Bob - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) @@ -216,7 +228,7 @@ func TestPurgeRoom(t *testing.T) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) // Invite Bob - inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + inviteEvent := room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) @@ -241,8 +253,8 @@ func TestPurgeRoom(t *testing.T) { // this starts the JetStream consumers syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) - federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) - rsAPI.SetFederationAPI(nil, nil) + fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) + rsAPI.SetFederationAPI(fsAPI, nil) // Create the room if err = api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { @@ -250,13 +262,9 @@ func TestPurgeRoom(t *testing.T) { } // some dummy entries to validate after purging - publishResp := &api.PerformPublishResponse{} - if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + if err = rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: spec.Public}); err != nil { t.Fatal(err) } - if publishResp.Error != nil { - t.Fatal(publishResp.Error) - } isPublished, err := db.GetPublishedRoom(ctx, room.ID) if err != nil { @@ -324,8 +332,7 @@ func TestPurgeRoom(t *testing.T) { } // purge the room from the database - purgeResp := &api.PerformAdminPurgeRoomResponse{} - if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil { + if err = rsAPI.PerformAdminPurgeRoom(ctx, room.ID); err != nil { t.Fatal(err) } @@ -393,36 +400,37 @@ func TestPurgeRoom(t *testing.T) { type fledglingEvent struct { Type string StateKey *string - Sender string + SenderID string RoomID string Redacts string Depth int64 PrevEvents []interface{} } -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { +func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEvent) { t.Helper() roomVer := gomatrixserverlib.RoomVersionV9 seed := make([]byte, ed25519.SeedSize) // zero seed key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, + eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ + SenderID: ev.SenderID, Type: ev.Type, StateKey: ev.StateKey, RoomID: ev.RoomID, Redacts: ev.Redacts, Depth: ev.Depth, PrevEvents: ev.PrevEvents, - } + }) err := eb.SetContent(map[string]interface{}{}) if err != nil { t.Fatalf("mustCreateEvent: failed to marshal event content %v", err) } - signedEvent, err := eb.Build(time.Now(), "localhost", "ed25519:test", key, roomVer) + + signedEvent, err := eb.Build(time.Now(), "localhost", "ed25519:test", key) if err != nil { t.Fatalf("mustCreateEvent: failed to sign event: %s", err) } - h := signedEvent.Headered(roomVer) + h := &types.HeaderedEvent{PDU: signedEvent} return h } @@ -443,14 +451,14 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, - Sender: alice.ID, + Type: spec.MRoomRedaction, + SenderID: alice.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, PrevEvents: []interface{}{redactedEvent.EventID()}, }) - room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + room.InsertEvent(t, builderEv) }, }, { @@ -460,14 +468,14 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, bob, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, - Sender: alice.ID, + Type: spec.MRoomRedaction, + SenderID: alice.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, PrevEvents: []interface{}{redactedEvent.EventID()}, }) - room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + room.InsertEvent(t, builderEv) }, }, { @@ -477,14 +485,14 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, - Sender: bob.ID, + Type: spec.MRoomRedaction, + SenderID: bob.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, PrevEvents: []interface{}{redactedEvent.EventID()}, }) - room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + room.InsertEvent(t, builderEv) }, }, { @@ -493,14 +501,14 @@ func TestRedaction(t *testing.T) { redactedEvent := room.CreateAndInsert(t, bob, "m.room.message", map[string]interface{}{"body": "hello world"}) builderEv := mustCreateEvent(t, fledglingEvent{ - Type: gomatrixserverlib.MRoomRedaction, - Sender: charlie.ID, + Type: spec.MRoomRedaction, + SenderID: charlie.ID, RoomID: room.ID, Redacts: redactedEvent.EventID(), Depth: redactedEvent.Depth() + 1, PrevEvents: []interface{}{redactedEvent.EventID()}, }) - room.InsertEvent(t, builderEv.Headered(gomatrixserverlib.RoomVersionV9)) + room.InsertEvent(t, builderEv) }, }, } @@ -516,6 +524,9 @@ func TestRedaction(t *testing.T) { t.Fatal(err) } + natsInstance := &jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { authEvents := []types.EventNID{} @@ -523,10 +534,10 @@ func TestRedaction(t *testing.T) { var err error room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(charlie.ID)) @@ -535,7 +546,7 @@ func TestRedaction(t *testing.T) { } for _, ev := range room.Events() { - roomInfo, err = db.GetOrCreateRoomInfo(ctx, ev.Event) + roomInfo, err = db.GetOrCreateRoomInfo(ctx, ev.PDU) assert.NoError(t, err) assert.NotNil(t, roomInfo) evTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type()) @@ -544,15 +555,15 @@ func TestRedaction(t *testing.T) { stateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey()) assert.NoError(t, err) - eventNID, stateAtEvent, err := db.StoreEvent(ctx, ev.Event, roomInfo, evTypeNID, stateKeyNID, authEvents, false) + eventNID, stateAtEvent, err := db.StoreEvent(ctx, ev.PDU, roomInfo, evTypeNID, stateKeyNID, authEvents, false) assert.NoError(t, err) if ev.StateKey() != nil { authEvents = append(authEvents, eventNID) } // Calculate the snapshotNID etc. - plResolver := state.NewStateResolution(db, roomInfo) - stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.Event, false) + plResolver := state.NewStateResolution(db, roomInfo, rsAPI) + stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false) assert.NoError(t, err) // Update the room @@ -563,15 +574,15 @@ func TestRedaction(t *testing.T) { err = updater.Commit() assert.NoError(t, err) - _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Event, &plResolver) + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{}) assert.NoError(t, err) if redactedEvent != nil { assert.Equal(t, ev.Redacts(), redactedEvent.EventID()) } - if ev.Type() == gomatrixserverlib.MRoomRedaction { + if ev.Type() == spec.MRoomRedaction { nids, err := db.EventNIDs(ctx, []string{ev.Redacts()}) assert.NoError(t, err) - evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID}) + evs, err := db.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[ev.Redacts()].EventNID}) assert.NoError(t, err) assert.Equal(t, 1, len(evs)) assert.Equal(t, tc.wantRedacted, evs[0].Redacted()) @@ -581,3 +592,482 @@ func TestRedaction(t *testing.T) { } }) } + +func TestQueryRestrictedJoinAllowed(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + + // a room we don't create in the database + allowedByRoomNotExists := test.NewRoom(t, alice) + + // a room we create in the database, used for authorisation + allowedByRoomExists := test.NewRoom(t, alice) + allowedByRoomExists.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Join, + }, test.WithStateKey(bob.ID)) + + testCases := []struct { + name string + prepareRoomFunc func(t *testing.T) *test.Room + wantResponse string + wantError bool + }{ + { + name: "public room unrestricted", + prepareRoomFunc: func(t *testing.T) *test.Room { + return test.NewRoom(t, alice) + }, + wantResponse: "", + }, + { + name: "room version without restrictions", + prepareRoomFunc: func(t *testing.T) *test.Room { + return test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV7)) + }, + }, + { + name: "restricted only", // bob is not allowed to join + prepareRoomFunc: func(t *testing.T) *test.Room { + r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8)) + r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{ + "join_rule": spec.Restricted, + }, test.WithStateKey("")) + return r + }, + wantError: true, + }, + { + name: "knock_restricted", + prepareRoomFunc: func(t *testing.T) *test.Room { + r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8)) + r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{ + "join_rule": spec.KnockRestricted, + }, test.WithStateKey("")) + return r + }, + wantError: true, + }, + { + name: "restricted with pending invite", // bob should be allowed to join + prepareRoomFunc: func(t *testing.T) *test.Room { + r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8)) + r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{ + "join_rule": spec.Restricted, + }, test.WithStateKey("")) + r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Invite, + }, test.WithStateKey(bob.ID)) + return r + }, + wantResponse: "", + }, + { + name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room + prepareRoomFunc: func(t *testing.T) *test.Room { + r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV10)) + r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{ + "join_rule": spec.KnockRestricted, + "allow": []map[string]interface{}{ + { + "room_id": allowedByRoomNotExists.ID, + "type": spec.MRoomMembership, + }, + }, + }, test.WithStateKey("")) + r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Join, + "join_authorised_via_users_server": alice.ID, + }, test.WithStateKey(bob.ID)) + return r + }, + wantError: true, + }, + { + name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room + prepareRoomFunc: func(t *testing.T) *test.Room { + r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV10)) + r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{ + "join_rule": spec.KnockRestricted, + "allow": []map[string]interface{}{ + { + "room_id": allowedByRoomExists.ID, + "type": spec.MRoomMembership, + }, + }, + }, test.WithStateKey("")) + r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Join, + "join_authorised_via_users_server": alice.ID, + }, test.WithStateKey(bob.ID)) + return r + }, + wantResponse: alice.ID, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + natsInstance := jetstream.NATSInstance{} + defer close() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.prepareRoomFunc == nil { + t.Fatal("missing prepareRoomFunc") + } + testRoom := tc.prepareRoomFunc(t) + // Create the room + if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, testRoom.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, allowedByRoomExists.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + roomID, _ := spec.NewRoomID(testRoom.ID) + userID, _ := spec.NewUserID(bob.ID, true) + got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String())) + if tc.wantError && err == nil { + t.Fatal("expected error, got none") + } + if !tc.wantError && err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc.wantResponse, got) { + t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, got) + } + }) + } + }) +} + +func TestUpgrade(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + ctx := context.Background() + + spaceChild := test.NewRoom(t, alice) + validateTuples := []gomatrixserverlib.StateKeyTuple{ + {EventType: spec.MRoomCreate}, + {EventType: spec.MRoomPowerLevels}, + {EventType: spec.MRoomJoinRules}, + {EventType: spec.MRoomName}, + {EventType: spec.MRoomCanonicalAlias}, + {EventType: "m.room.tombstone"}, + {EventType: "m.custom.event"}, + {EventType: "m.space.child", StateKey: spaceChild.ID}, + {EventType: "m.custom.event", StateKey: alice.ID}, + {EventType: spec.MRoomMember, StateKey: charlie.ID}, // ban should be transferred + } + + validate := func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) { + + oldRoomState := &api.QueryCurrentStateResponse{} + if err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{ + RoomID: oldRoomID, + StateTuples: validateTuples, + }, oldRoomState); err != nil { + t.Fatal(err) + } + + newRoomState := &api.QueryCurrentStateResponse{} + if err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{ + RoomID: newRoomID, + StateTuples: validateTuples, + }, newRoomState); err != nil { + t.Fatal(err) + } + + // the old room should have a tombstone event + ev := oldRoomState.StateEvents[gomatrixserverlib.StateKeyTuple{EventType: "m.room.tombstone"}] + replacementRoom := gjson.GetBytes(ev.Content(), "replacement_room").Str + if replacementRoom != newRoomID { + t.Fatalf("tombstone event has replacement_room '%s', expected '%s'", replacementRoom, newRoomID) + } + + // the new room should have a predecessor equal to the old room + ev = newRoomState.StateEvents[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate}] + predecessor := gjson.GetBytes(ev.Content(), "predecessor.room_id").Str + if predecessor != oldRoomID { + t.Fatalf("got predecessor room '%s', expected '%s'", predecessor, oldRoomID) + } + + for _, tuple := range validateTuples { + // Skip create and powerlevel event (new room has e.g. predecessor event, old room has restricted powerlevels) + switch tuple.EventType { + case spec.MRoomCreate, spec.MRoomPowerLevels, spec.MRoomCanonicalAlias: + continue + } + oldEv, ok := oldRoomState.StateEvents[tuple] + if !ok { + t.Logf("skipping tuple %#v as it doesn't exist in the old room", tuple) + continue + } + newEv, ok := newRoomState.StateEvents[tuple] + if !ok { + t.Logf("skipping tuple %#v as it doesn't exist in the new room", tuple) + continue + } + + if !reflect.DeepEqual(oldEv.Content(), newEv.Content()) { + t.Logf("OldEvent QueryCurrentState: %s", string(oldEv.Content())) + t.Logf("NewEvent QueryCurrentState: %s", string(newEv.Content())) + t.Errorf("event content mismatch") + } + } + } + + testCases := []struct { + name string + upgradeUser string + roomFunc func(rsAPI api.RoomserverInternalAPI) string + validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) + wantNewRoom bool + }{ + { + name: "invalid roomID", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + return "!doesnotexist:test" + }, + }, + { + name: "powerlevel too low", + upgradeUser: bob.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + room := test.NewRoom(t, alice) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + return room.ID + }, + }, + { + name: "successful upgrade on new room", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + room := test.NewRoom(t, alice) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + return room.ID + }, + wantNewRoom: true, + validateFunc: validate, + }, + { + name: "successful upgrade on new room with other state events", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + r := test.NewRoom(t, alice) + r.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{ + "name": "my new name", + }, test.WithStateKey("")) + r.CreateAndInsert(t, alice, spec.MRoomCanonicalAlias, eventutil.CanonicalAliasContent{ + Alias: "#myalias:test", + }, test.WithStateKey("")) + + // this will be transferred + r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{ + "random": "i should exist", + }, test.WithStateKey("")) + + // the following will be ignored + r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{ + "random": "i will be ignored", + }, test.WithStateKey(alice.ID)) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + return r.ID + }, + wantNewRoom: true, + validateFunc: validate, + }, + { + name: "with published room", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + r := test.NewRoom(t, alice) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + RoomID: r.ID, + Visibility: spec.Public, + }); err != nil { + t.Fatal(err) + } + + return r.ID + }, + wantNewRoom: true, + validateFunc: func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) { + validate(t, oldRoomID, newRoomID, rsAPI) + // check that the new room is published + res := &api.QueryPublishedRoomsResponse{} + if err := rsAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{RoomID: newRoomID}, res); err != nil { + t.Fatal(err) + } + if len(res.RoomIDs) == 0 { + t.Fatalf("expected room to be published, but wasn't: %#v", res.RoomIDs) + } + }, + }, + { + name: "with alias", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + r := test.NewRoom(t, alice) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + if err := rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{ + RoomID: r.ID, + Alias: "#myroomalias:test", + }, &api.SetRoomAliasResponse{}); err != nil { + t.Fatal(err) + } + + return r.ID + }, + wantNewRoom: true, + validateFunc: func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) { + validate(t, oldRoomID, newRoomID, rsAPI) + // check that the old room has no aliases + res := &api.GetAliasesForRoomIDResponse{} + if err := rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: oldRoomID}, res); err != nil { + t.Fatal(err) + } + if len(res.Aliases) != 0 { + t.Fatalf("expected old room aliases to be empty, but wasn't: %#v", res.Aliases) + } + + // check that the new room has aliases + if err := rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: newRoomID}, res); err != nil { + t.Fatal(err) + } + if len(res.Aliases) == 0 { + t.Fatalf("expected room aliases to be transferred, but wasn't: %#v", res.Aliases) + } + }, + }, + { + name: "bans are transferred", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + r := test.NewRoom(t, alice) + r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Ban, + }, test.WithStateKey(charlie.ID)) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + return r.ID + }, + wantNewRoom: true, + validateFunc: validate, + }, + { + name: "space childs are transferred", + upgradeUser: alice.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + r := test.NewRoom(t, alice) + + r.CreateAndInsert(t, alice, "m.space.child", map[string]interface{}{}, test.WithStateKey(spaceChild.ID)) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + return r.ID + }, + wantNewRoom: true, + validateFunc: validate, + }, + { + name: "custom state is not taken to the new room", // https://github.com/matrix-org/dendrite/issues/2912 + upgradeUser: charlie.ID, + roomFunc: func(rsAPI api.RoomserverInternalAPI) string { + r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV6)) + // Bob and Charlie join + r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{"membership": spec.Join}, test.WithStateKey(bob.ID)) + r.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{"membership": spec.Join}, test.WithStateKey(charlie.ID)) + + // make Charlie an admin so the room can be upgraded + r.CreateAndInsert(t, alice, spec.MRoomPowerLevels, gomatrixserverlib.PowerLevelContent{ + Users: map[string]int64{ + charlie.ID: 100, + }, + }, test.WithStateKey("")) + + // Alice creates a custom event + r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{ + "random": "data", + }, test.WithStateKey(alice.ID)) + r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{"membership": spec.Leave}, test.WithStateKey(alice.ID)) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + return r.ID + }, + wantNewRoom: true, + validateFunc: validate, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + natsInstance := jetstream.NATSInstance{} + defer close() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + rsAPI.SetFederationAPI(nil, nil) + rsAPI.SetUserAPI(userAPI) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.roomFunc == nil { + t.Fatalf("missing roomFunc") + } + if tc.upgradeUser == "" { + tc.upgradeUser = alice.ID + } + roomID := tc.roomFunc(rsAPI) + + userID, err := spec.NewUserID(tc.upgradeUser, true) + if err != nil { + t.Fatalf("upgrade userID is invalid") + } + newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion()) + if err != nil && tc.wantNewRoom { + t.Fatal(err) + } + + if tc.wantNewRoom && newRoomID == "" { + t.Fatalf("expected a new room, but the upgrade failed") + } + if !tc.wantNewRoom && newRoomID != "" { + t.Fatalf("expected no new room, but the upgrade succeeded") + } + if tc.validateFunc != nil { + tc.validateFunc(t, roomID, newRoomID, rsAPI) + } + }) + } + }) +} diff --git a/roomserver/state/state.go b/roomserver/state/state.go index c3842784eb..1e776ff6c8 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -24,10 +24,12 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -41,21 +43,23 @@ type StateResolutionStorage interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) } type StateResolution struct { db StateResolutionStorage roomInfo *types.RoomInfo - events map[types.EventNID]*gomatrixserverlib.Event + events map[types.EventNID]gomatrixserverlib.PDU + Querier api.QuerySenderIDAPI } -func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, - events: make(map[types.EventNID]*gomatrixserverlib.Event), + events: make(map[types.EventNID]gomatrixserverlib.PDU), + Querier: querier, } } @@ -85,7 +89,10 @@ func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatri return nil, fmt.Errorf("unable to find power level event") } - events, err := p.db.Events(ctx, p.roomInfo, []types.EventNID{plNID}) + if p.roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + events, err := p.db.Events(ctx, p.roomInfo.RoomVersion, []types.EventNID{plNID}) if err != nil { return nil, err } @@ -702,7 +709,7 @@ func init() { // Returns a numeric ID for the snapshot of the state before the event. func (v *StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, - event *gomatrixserverlib.Event, + event gomatrixserverlib.PDU, isRejected bool, ) (types.StateSnapshotNID, error) { trace, ctx := internal.StartRegion(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") @@ -878,10 +885,12 @@ func (v *StateResolution) resolveConflicts( trace, ctx := internal.StartRegion(ctx, "StateResolution.resolveConflicts") defer trace.EndRegion() - stateResAlgo, err := version.StateResAlgorithm() + verImpl, err := gomatrixserverlib.GetRoomVersion(version) if err != nil { return nil, err } + + stateResAlgo := verImpl.StateResAlgorithm() switch stateResAlgo { case gomatrixserverlib.StateResV1: return v.resolveConflictsV1(ctx, notConflicted, conflicted) @@ -940,7 +949,9 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) + }) // Map from the full events back to numeric state entries. for _, resolvedEvent := range resolvedEvents { @@ -993,8 +1004,8 @@ func (v *StateResolution) resolveConflictsV2( // For each conflicted event, we will add a new set of auth events. Auth // events may be duplicated across these sets but that's OK. - authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) - authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) + authSets := make(map[string][]gomatrixserverlib.PDU, len(conflicted)) + authEvents := make([]gomatrixserverlib.PDU, 0, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3) knownAuthEvents := make(map[string]types.Event, estimate*3) @@ -1044,7 +1055,7 @@ func (v *StateResolution) resolveConflictsV2( gotAuthEvents = nil // nolint:ineffassign // Resolve the conflicts. - resolvedEvents := func() []*gomatrixserverlib.Event { + resolvedEvents := func() []gomatrixserverlib.PDU { resolvedTrace, _ := internal.StartRegion(ctx, "StateResolution.ResolveStateConflictsV2") defer resolvedTrace.EndRegion() @@ -1052,6 +1063,9 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, + func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) + }, ) }() @@ -1117,11 +1131,11 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E // Returns an error if there was a problem talking to the database. func (v *StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, -) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { +) ([]gomatrixserverlib.PDU, map[string]types.StateEntry, error) { trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateEvents") defer trace.EndRegion() - result := make([]*gomatrixserverlib.Event, 0, len(entries)) + result := make([]gomatrixserverlib.PDU, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries)) eventNIDs := make(types.EventNIDs, 0, len(entries)) for _, entry := range entries { @@ -1132,7 +1146,11 @@ func (v *StateResolution) loadStateEvents( eventNIDs = append(eventNIDs, entry.EventNID) } } - events, err := v.db.Events(ctx, v.roomInfo, eventNIDs) + + if v.roomInfo == nil { + return nil, nil, types.ErrorInvalidRoomInfo + } + events, err := v.db.Events(ctx, v.roomInfo.RoomVersion, eventNIDs) if err != nil { return nil, nil, err } @@ -1142,9 +1160,9 @@ func (v *StateResolution) loadStateEvents( if !ok { panic(fmt.Errorf("corrupt DB: Missing event numeric ID %d", entry.EventNID)) } - result = append(result, event.Event) - eventIDMap[event.Event.EventID()] = entry - v.events[entry.EventNID] = event.Event + result = append(result, event.PDU) + eventIDMap[event.PDU.EventID()] = entry + v.events[entry.EventNID] = event.PDU } return result, eventIDMap, nil } @@ -1161,8 +1179,8 @@ type authEventLoader struct { // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. func (l *authEventLoader) loadAuthEvents( - ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event, -) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { + ctx context.Context, roomInfo *types.RoomInfo, event gomatrixserverlib.PDU, eventMap map[string]types.Event, +) ([]gomatrixserverlib.PDU, map[string]types.StateEntry, error) { l.Lock() defer l.Unlock() authEvents := []types.Event{} // our returned list @@ -1263,9 +1281,9 @@ func (l *authEventLoader) loadAuthEvents( }, } } - nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents)) + nakedEvents := make([]gomatrixserverlib.PDU, 0, len(authEvents)) for _, authEvent := range authEvents { - nakedEvents = append(nakedEvents, authEvent.Event) + nakedEvents = append(nakedEvents, authEvent.PDU) } return nakedEvents, stateEntryMap, nil } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a577f46505..e9b4609ecc 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -16,8 +16,11 @@ package storage import ( "context" + "crypto/ed25519" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -26,8 +29,10 @@ import ( ) type Database interface { + UserRoomKeys // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) @@ -40,7 +45,7 @@ type Database interface { ) (types.StateSnapshotNID, error) MissingAuthPrevEvents( - ctx context.Context, e *gomatrixserverlib.Event, + ctx context.Context, e gomatrixserverlib.PDU, ) (missingAuth, missingPrev []string, err error) // Look up the state of a room at each event for a list of string event IDs. @@ -71,12 +76,12 @@ type Database interface { ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error. - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) + StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. @@ -101,7 +106,7 @@ type Database interface { // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. - LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]string, types.StateSnapshotNID, int64, error) // Look up the active invites targeting a user in a room and return the // numeric state key IDs for the user IDs who sent them along with the event IDs for the invites. // Returns an error if there was a problem talking to the database. @@ -128,7 +133,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, membershipNID tables.MembershipState, stillInRoom, isRoomForgotten bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. @@ -138,7 +143,7 @@ type Database interface { // not found. // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - // Publish or unpublish a room from the room directory. + // PerformPublish publishes or unpublishes a room from the room directory. Returns a database error, if any. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) @@ -150,8 +155,8 @@ type Database interface { // GetStateEvent returns the state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) - GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) + GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. @@ -162,7 +167,7 @@ type Database interface { // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetServerInRoom returns true if we think a server is in a given room or false otherwise. - GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. @@ -170,7 +175,7 @@ type Database interface { // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error - GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]gomatrixserverlib.PDU, error) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error @@ -180,23 +185,43 @@ type Database interface { // a membership of "leave" when calculating history visibility. GetMembershipForHistoryVisibility( ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, - ) (map[string]*gomatrixserverlib.HeaderedEvent, error) - GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) + ) (map[string]*types.HeaderedEvent, error) + GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) + GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) MaybeRedactEvent( - ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, - ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, + ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) +} + +type UserRoomKeys interface { + // InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used + // when creating keys locally. + InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) + // InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server + InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) + // SelectUserRoomPrivateKey selects the private key for the given user and room combination + SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) + // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. + // If a senderKey can't be found, it is omitted in the result. + // TODO: Why is the result map indexed by string not public key? + // TODO: Shouldn't the input & result map be changed to be indexed by string instead of the RoomID struct? + SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) } type RoomDatabase interface { EventDatabase + UserRoomKeys + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) // IsEventRejected returns true if the event is known and rejected. IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) - MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) + MissingAuthPrevEvents(ctx context.Context, e gomatrixserverlib.PDU) (missingAuth, missingPrev []string, err error) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) @@ -205,11 +230,11 @@ type RoomDatabase interface { BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) - GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]string, types.StateSnapshotNID, int64, error) + GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) } type EventDatabase interface { @@ -223,11 +248,11 @@ type EventDatabase interface { SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error // (nil if there was nothing to do) MaybeRedactEvent( - ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, - ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, + ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) + StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) } diff --git a/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go new file mode 100644 index 0000000000..1b1dd44d3c --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha.go @@ -0,0 +1,120 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/util" +) + +func UpDropEventReferenceSHAEvents(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_events DROP COLUMN IF EXISTS reference_sha256;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func UpDropEventReferenceSHAPrevEvents(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE roomserver_previous_events DROP CONSTRAINT roomserver_previous_event_id_unique;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events DROP COLUMN IF EXISTS previous_reference_sha256;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + + // figure out if there are duplicates + dupeRows, err := tx.QueryContext(ctx, `SELECT previous_event_id FROM roomserver_previous_events GROUP BY previous_event_id HAVING count(previous_event_id) > 1`) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeRows, "failed to close rows") + + var prevEvents []string + var prevEventID string + for dupeRows.Next() { + if err = dupeRows.Scan(&prevEventID); err != nil { + return err + } + prevEvents = append(prevEvents, prevEventID) + } + if dupeRows.Err() != nil { + return dupeRows.Err() + } + + // if we found duplicates, check if we can combine them, e.g. they are in the same room + for _, dupeID := range prevEvents { + var dupeNIDsRows *sql.Rows + dupeNIDsRows, err = tx.QueryContext(ctx, `SELECT event_nids FROM roomserver_previous_events WHERE previous_event_id = $1`, dupeID) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeNIDsRows, "failed to close rows") + var dupeNIDs []int64 + for dupeNIDsRows.Next() { + var nids pq.Int64Array + if err = dupeNIDsRows.Scan(&nids); err != nil { + return err + } + dupeNIDs = append(dupeNIDs, nids...) + } + + if dupeNIDsRows.Err() != nil { + return dupeNIDsRows.Err() + } + // dedupe NIDs + dupeNIDs = dupeNIDs[:util.SortAndUnique(nids(dupeNIDs))] + // now that we have all NIDs, check which room they belong to + var roomCount int + err = tx.QueryRowContext(ctx, `SELECT count(distinct room_nid) FROM roomserver_events WHERE event_nid = ANY($1)`, pq.Array(dupeNIDs)).Scan(&roomCount) + if err != nil { + return err + } + // if the events are from different rooms, that's bad and we can't continue + if roomCount > 1 { + return fmt.Errorf("detected events (%v) referenced for different rooms (%v)", dupeNIDs, roomCount) + } + // otherwise delete the dupes + _, err = tx.ExecContext(ctx, "DELETE FROM roomserver_previous_events WHERE previous_event_id = $1", dupeID) + if err != nil { + return fmt.Errorf("unable to delete duplicates: %w", err) + } + + // insert combined values + _, err = tx.ExecContext(ctx, "INSERT INTO roomserver_previous_events (previous_event_id, event_nids) VALUES ($1, $2)", dupeID, pq.Array(dupeNIDs)) + if err != nil { + return fmt.Errorf("unable to insert new event NIDs: %w", err) + } + } + + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events ADD CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id);`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +type nids []int64 + +func (s nids) Len() int { return len(s) } +func (s nids) Less(i, j int) bool { return s[i] < s[j] } +func (s nids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go new file mode 100644 index 0000000000..c79daac5fc --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230516154000_drop_reference_sha_test.go @@ -0,0 +1,60 @@ +package deltas + +import ( + "testing" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +func TestUpDropEventReferenceSHAPrevEvents(t *testing.T) { + + cfg, ctx, close := testrig.CreateConfig(t, test.DBTypePostgres) + defer close() + + db, err := sqlutil.Open(&cfg.Global.DatabaseOptions, sqlutil.NewDummyWriter()) + assert.Nil(t, err) + assert.NotNil(t, db) + defer db.Close() + + // create the table in the old layout + _, err = db.ExecContext(ctx.Context(), ` +CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + event_nids BIGINT[] NOT NULL, + previous_reference_sha256 BYTEA NOT NULL, + CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id, previous_reference_sha256) +);`) + assert.Nil(t, err) + + // create the events table as well, slimmed down with one eventNID + _, err = db.ExecContext(ctx.Context(), ` +CREATE SEQUENCE IF NOT EXISTS roomserver_event_nid_seq; +CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_nid_seq'), + room_nid BIGINT NOT NULL +); + +INSERT INTO roomserver_events (event_nid, room_nid) VALUES (1, 1) +`) + assert.Nil(t, err) + + // insert duplicate prev events with different event_nids + stmt, err := db.PrepareContext(ctx.Context(), `INSERT INTO roomserver_previous_events (previous_event_id, event_nids, previous_reference_sha256) VALUES ($1, $2, $3)`) + assert.Nil(t, err) + assert.NotNil(t, stmt) + _, err = stmt.ExecContext(ctx.Context(), "1", pq.Array([]int64{1, 2}), "a") + assert.Nil(t, err) + _, err = stmt.ExecContext(ctx.Context(), "1", pq.Array([]int64{1, 2, 3}), "b") + assert.Nil(t, err) + // execute the migration + txn, err := db.Begin() + assert.Nil(t, err) + assert.NotNil(t, txn) + defer txn.Rollback() + err = UpDropEventReferenceSHAPrevEvents(ctx.Context(), txn) + assert.NoError(t, err) +} diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c935608a51..a00b4b1d76 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,10 +22,9 @@ import ( "sort" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -62,9 +61,6 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( -- Needed for state resolution. -- An event may only appear in this table once. event_id TEXT NOT NULL CONSTRAINT roomserver_event_id_unique UNIQUE, - -- The sha256 reference hash for the event. - -- Needed for setting reference hashes when sending new events. - reference_sha256 BYTEA NOT NULL, -- A list of numeric IDs for events that can authenticate this event. auth_event_nids BIGINT[] NOT NULL, is_rejected BOOLEAN NOT NULL DEFAULT FALSE @@ -75,10 +71,10 @@ CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_event ` const insertEventSQL = "" + - "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, auth_event_nids, depth, is_rejected)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = TRUE" + + " SET is_rejected = $7 WHERE e.event_id = $4 AND e.is_rejected = TRUE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -130,12 +126,9 @@ const selectEventIDSQL = "" + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" const bulkSelectStateAtEventAndReferenceSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id" + " FROM roomserver_events WHERE event_nid = ANY($1)" -const bulkSelectEventReferenceSQL = "" + - "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid = ANY($1)" - const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)" @@ -167,7 +160,6 @@ type eventStatements struct { updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt bulkSelectUnsentEventNIDStmt *sql.Stmt @@ -178,7 +170,18 @@ type eventStatements struct { func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) - return err + if err != nil { + return err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: drop column reference_sha from roomserver_events", + Up: deltas.UpDropEventReferenceSHAEvents, + }, + }...) + return m.Up(context.Background()) } func PrepareEventsTable(db *sql.DB) (tables.Events, error) { @@ -197,7 +200,6 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, - {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, @@ -214,7 +216,6 @@ func (s *eventStatements) InsertEvent( eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, - referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, @@ -224,7 +225,7 @@ func (s *eventStatements) InsertEvent( stmt := sqlutil.TxStmt(txn, s.insertEventStmt) err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + eventID, eventNIDsAsArray(authEventNIDs), depth, isRejected, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err @@ -441,11 +442,10 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( eventNID int64 stateSnapshotNID int64 eventID string - eventSHA256 []byte ) for ; rows.Next(); i++ { if err = rows.Scan( - &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, ); err != nil { return nil, err } @@ -455,32 +455,6 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( result.EventNID = types.EventNID(eventNID) result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) result.EventID = eventID - result.EventSHA256 = eventSHA256 - } - if err = rows.Err(); err != nil { - return nil, err - } - if i != len(eventNIDs) { - return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) - } - return results, nil -} - -func (s *eventStatements) BulkSelectEventReference( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.EventReference, error) { - rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") - results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) - i := 0 - for ; rows.Next(); i++ { - result := &results[i] - if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { - return nil, err - } } if err = rows.Err(); err != nil { return nil, err diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index d774b78929..835a43b2d9 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -21,13 +21,13 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const membershipSchema = ` @@ -450,7 +450,7 @@ func (s *membershipStatements) SelectLocalServerInRoom( func (s *membershipStatements) SelectServerInRoom( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, serverName gomatrixserverlib.ServerName, + roomNID types.RoomNID, serverName spec.ServerName, ) (bool, error) { var nid types.RoomNID stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 26999a290f..ceb5e26bab 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -32,11 +33,9 @@ const previousEventSchema = ` CREATE TABLE IF NOT EXISTS roomserver_previous_events ( -- The string event ID taken from the prev_events key of an event. previous_event_id TEXT NOT NULL, - -- The SHA256 reference hash taken from the prev_events key of an event. - previous_reference_sha256 BYTEA NOT NULL, -- A list of numeric event IDs of events that reference this prev_event. event_nids BIGINT[] NOT NULL, - CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id, previous_reference_sha256) + CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id) ); ` @@ -47,17 +46,17 @@ CREATE TABLE IF NOT EXISTS roomserver_previous_events ( // The lock is necessary to avoid data races when checking whether an event is already referenced by another event. const insertPreviousEventSQL = "" + "INSERT INTO roomserver_previous_events" + - " (previous_event_id, previous_reference_sha256, event_nids)" + - " VALUES ($1, $2, array_append('{}'::bigint[], $3))" + + " (previous_event_id, event_nids)" + + " VALUES ($1, array_append('{}'::bigint[], $2))" + " ON CONFLICT ON CONSTRAINT roomserver_previous_event_id_unique" + - " DO UPDATE SET event_nids = array_append(roomserver_previous_events.event_nids, $3)" + - " WHERE $3 != ALL(roomserver_previous_events.event_nids)" + " DO UPDATE SET event_nids = array_append(roomserver_previous_events.event_nids, $2)" + + " WHERE $2 != ALL(roomserver_previous_events.event_nids)" // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = "" + "SELECT 1 FROM roomserver_previous_events" + - " WHERE previous_event_id = $1 AND previous_reference_sha256 = $2" + " WHERE previous_event_id = $1" type previousEventStatements struct { insertPreviousEventStmt *sql.Stmt @@ -66,7 +65,18 @@ type previousEventStatements struct { func CreatePrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) - return err + if err != nil { + return err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: drop column reference_sha from roomserver_prev_events", + Up: deltas.UpDropEventReferenceSHAPrevEvents, + }, + }...) + return m.Up(context.Background()) } func PreparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { @@ -82,12 +92,11 @@ func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, - previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ctx, previousEventID, int64(eventNID), ) return err } @@ -95,9 +104,9 @@ func (s *previousEventStatements) InsertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) SelectPreviousEventExists( - ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, + ctx context.Context, txn *sql.Tx, eventID string, ) error { var ok int64 stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + return stmt.QueryRowContext(ctx, eventID).Scan(&ok) } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 0e83cfc256..32ed06a131 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -205,26 +205,30 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, -) (map[string]*gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*types.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt) rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID) if err != nil { return nil, err } defer rows.Close() // nolint: errcheck - result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + result := make(map[string]*types.HeaderedEvent, len(eventIDs)) var evJson []byte var eventID string var membershipEventID string - knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + knownEvents := make(map[string]*types.HeaderedEvent, len(eventIDs)) + verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) + if err != nil { + return nil, err + } for rows.Next() { if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil { return nil, err } if len(evJson) == 0 { - result[eventID] = &gomatrixserverlib.HeaderedEvent{} + result[eventID] = &types.HeaderedEvent{} continue } // If we already know this event, don't try to marshal the json again @@ -232,13 +236,13 @@ func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( result[eventID] = ev continue } - event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion) + event, err := verImpl.NewEventFromTrustedJSON(evJson, false) if err != nil { - result[eventID] = &gomatrixserverlib.HeaderedEvent{} + result[eventID] = &types.HeaderedEvent{} // not fatal continue } - he := event.Headered(roomInfo.RoomVersion) + he := &types.HeaderedEvent{PDU: event} result[eventID] = he knownEvents[membershipEventID] = he } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 19cde54105..453ff45da5 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -131,6 +131,9 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRedactionsTable(db); err != nil { return err } + if err := CreateUserRoomKeysTable(db); err != nil { + return err + } return nil } @@ -192,6 +195,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + userRoomKeys, err := PrepareUserRoomKeysTable(db) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, EventDatabase: shared.EventDatabase{ @@ -215,6 +223,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, Purge: purge, + UserRoomKeyTable: userRoomKeys, } return nil } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go new file mode 100644 index 0000000000..202b0abc10 --- /dev/null +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -0,0 +1,152 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "crypto/ed25519" + "database/sql" + "errors" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +const userRoomKeysSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( + user_nid INTEGER NOT NULL, + room_nid INTEGER NOT NULL, + pseudo_id_key BYTEA NULL, -- may be null for users not local to the server + pseudo_id_pub_key BYTEA NOT NULL, + CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) +); +` + +const insertUserRoomPrivateKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_pub_key = $3 + RETURNING (pseudo_id_pub_key) +` + +const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` + +type userRoomKeysStatements struct { + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt + selectUserNIDsStmt *sql.Stmt +} + +func CreateUserRoomKeysTable(db *sql.DB) error { + _, err := db.Exec(userRoomKeysSchema) + return err +} + +func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { + s := &userRoomKeysStatements{} + return s, sqlutil.StatementList{ + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, + {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, + {&s.selectUserNIDsStmt, selectUserNIDsSQL}, + }.Prepare(db) +} + +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) + var result ed25519.PrivateKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) + + roomNIDs := make([]types.RoomNID, 0, len(senderKeys)) + var senders [][]byte + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + for _, key := range senderKeys[roomNID] { + senders = append(senders, key) + } + } + rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(senders)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs)) + var publicKey []byte + userRoomKeyPair := types.UserRoomKeyPair{} + for rows.Next() { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { + return nil, err + } + result[spec.Base64Bytes(publicKey).Encode()] = userRoomKeyPair + } + return result, rows.Err() +} diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f9c889cb15..a96e870721 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event * var inserted bool // Did the query result in a membership change? var retired []string // Did we retire any updates in the process? return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, string(event.SenderID())) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index dc1db0825c..70672a33e1 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -104,20 +104,11 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } -// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer -func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } - } - return nil - }) -} - -func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { - return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs) +func (u *RoomUpdater) Events(ctx context.Context, _ gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) { + if u.roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + return u.d.events(ctx, u.txn, u.roomInfo.RoomVersion, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -200,8 +191,8 @@ func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInf } // IsReferenced implements types.RoomRecentEventsUpdater -func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) +func (u *RoomUpdater) IsReferenced(eventID string) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventID) if err == nil { return true, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d40ef4b632..3c8b69c32f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -2,13 +2,19 @@ package shared import ( "context" + "crypto/ed25519" "database/sql" "encoding/json" + "errors" "fmt" "sort" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -40,6 +46,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published Purge tables.Purge + UserRoomKeyTable tables.UserRoomKeys GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -62,7 +69,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) GetMembershipForHistoryVisibility( ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, -) (map[string]*gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*types.HeaderedEvent, error) { return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) } @@ -275,17 +282,17 @@ func (d *Database) addState( var found bool for i := len(state) - 1; i >= 0; i-- { found = false + blocksLoop: for _, events := range blocks { for _, event := range events { if state[i].EventNID == event { found = true - break + break blocksLoop } } } if found { state = append(state[:i], state[i+1:]...) - i-- } } } @@ -391,18 +398,19 @@ func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo nids = append(nids, nid.EventNID) } - return d.events(ctx, txn, roomInfo, nids) + if roomInfo == nil { + return nil, types.ErrorInvalidRoomInfo + } + return d.events(ctx, txn, roomInfo.RoomVersion, nids) } -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { +func (d *Database) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) (references []string, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { var eventNIDs []types.EventNID eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) if err != nil { return } - references, err = d.EventsTable.BulkSelectEventReference(ctx, nil, eventNIDs) + eventNIDMap, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { return } @@ -410,6 +418,9 @@ func (d *Database) LatestEventIDs( if err != nil { return } + for _, eventID := range eventNIDMap { + references = append(references, eventID) + } return } @@ -480,14 +491,14 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, membershipState tables.MembershipState, stillInRoom, isRoomforgotten bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) + requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID)) return err }) if err != nil { - return 0, 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) + return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } senderMembershipEventNID, senderMembership, isRoomforgotten, err := @@ -496,12 +507,12 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req ) if err == sql.ErrNoRows { // The user has never been a member of that room - return 0, 0, false, false, nil + return 0, false, false, nil } else if err != nil { return } - return senderMembershipEventNID, senderMembership, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil } func (d *Database) GetMembershipEventNIDsForRoom( @@ -530,19 +541,15 @@ func (d *Database) GetInvitesForUser( return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } -func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { - return d.events(ctx, nil, roomInfo, eventNIDs) +func (d *EventDatabase) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) { + return d.events(ctx, nil, roomVersion, eventNIDs) } func (d *EventDatabase) events( - ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs, + ctx context.Context, txn *sql.Tx, roomVersion gomatrixserverlib.RoomVersion, inputEventNIDs types.EventNIDs, ) ([]types.Event, error) { - if roomInfo == nil { // this should never happen - return nil, fmt.Errorf("unable to parse events without roomInfo") - } - sort.Sort(inputEventNIDs) - events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) + events := make(map[types.EventNID]gomatrixserverlib.PDU, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) for _, nid := range inputEventNIDs { if event, ok := d.Cache.GetRoomServerEvent(nid); ok && event != nil { @@ -561,7 +568,7 @@ func (d *EventDatabase) events( } results = append(results, types.Event{ EventNID: nid, - Event: event, + PDU: event, }) } if !redactionsArePermanent { @@ -578,16 +585,21 @@ func (d *EventDatabase) events( eventIDs = map[types.EventNID]string{} } + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return nil, err + } + for _, eventJSON := range eventJSONs { redacted := gjson.GetBytes(eventJSON.EventJSON, "unsigned.redacted_because").Exists() - events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( - eventIDs[eventJSON.EventNID], eventJSON.EventJSON, redacted, roomInfo.RoomVersion, + events[eventJSON.EventNID], err = verImpl.NewEventFromTrustedJSONWithEventID( + eventIDs[eventJSON.EventNID], eventJSON.EventJSON, redacted, ) if err != nil { return nil, err } if event := events[eventJSON.EventNID]; event != nil { - d.Cache.StoreRoomServerEvent(eventJSON.EventNID, event) + d.Cache.StoreRoomServerEvent(eventJSON.EventNID, &types.HeaderedEvent{PDU: event}) } } results := make([]types.Event, 0, len(inputEventNIDs)) @@ -598,7 +610,7 @@ func (d *EventDatabase) events( } results = append(results, types.Event{ EventNID: nid, - Event: event, + PDU: event, }) } if !redactionsArePermanent { @@ -651,8 +663,28 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } +func (d *Database) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) { + // This should already be checked, let's check it anyway. + _, err = gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return 0, err + } + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, txn, roomID.String(), roomVersion) + if err != nil { + return err + } + return nil + }) + if err != nil { + return 0, err + } + // Not setting caches, as assignRoomNID already does this + return roomNID, err +} + // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. -func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) { +func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version // then we will use our configured default to create the room. // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom @@ -663,13 +695,17 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserve if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) } - if roomVersion == "" { - rv, ok := d.Cache.GetRoomVersion(event.RoomID()) - if ok { - roomVersion = rv - } + + roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID()) + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID()) + // if we found both, the roomNID and version in our cache, no need to query the database + if nidOK && versionOK { + return &types.RoomInfo{ + RoomNID: roomNID, + RoomVersion: cachedRoomVersion, + }, nil } - var roomNID types.RoomNID + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) if err != nil { @@ -686,6 +722,22 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserve }, err } +func (d *Database) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) { + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID) + if versionOK { + return cachedRoomVersion, nil + } + + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return "", err + } + if roomInfo == nil { + return "", nil + } + return roomInfo.RoomVersion, nil +} + func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil { @@ -715,7 +767,7 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe } func (d *EventDatabase) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, + ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool, ) (types.EventNID, types.StateAtEvent, error) { @@ -733,7 +785,6 @@ func (d *EventDatabase) StoreEvent( eventTypeNID, eventStateKeyNID, event.EventID(), - event.EventReference().EventSHA256, authEventNIDs, event.Depth(), isRejected, @@ -753,7 +804,7 @@ func (d *EventDatabase) StoreEvent( return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + if prevEvents := event.PrevEventIDs(); len(prevEvents) > 0 { // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like @@ -761,8 +812,8 @@ func (d *EventDatabase) StoreEvent( // to do writes however then this will need to go inside `Writer.Do`. // The following is a copy of RoomUpdater.StorePreviousEvents - for _, ref := range prevEvents { - if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + for _, eventID := range prevEvents { + if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, eventID, eventNID); err != nil { return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) } } @@ -807,7 +858,7 @@ func (d *Database) GetPublishedRooms(ctx context.Context, networkID string, incl } func (d *Database) MissingAuthPrevEvents( - ctx context.Context, e *gomatrixserverlib.Event, + ctx context.Context, e gomatrixserverlib.PDU, ) (missingAuth, missingPrev []string, err error) { authEventNIDs, err := d.EventNIDs(ctx, e.AuthEventIDs()) if err != nil { @@ -899,13 +950,13 @@ func (d *EventDatabase) assignStateKeyNID( return eventStateKeyNID, err } -func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( +func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) ( gomatrixserverlib.RoomVersion, error, ) { var err error var roomVersion gomatrixserverlib.RoomVersion // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { + if event.Type() != spec.MRoomCreate { return gomatrixserverlib.RoomVersion(""), nil } roomVersion = gomatrixserverlib.RoomVersionV1 @@ -922,6 +973,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( return roomVersion, err } +// nolint:gocyclo // MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events @@ -939,8 +991,9 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( // // Returns the redaction event and the redacted event if this call resulted in a redaction. func (d *EventDatabase) MaybeRedactEvent( - ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, -) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) { + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, + querier api.QuerySenderIDAPI, +) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) { var ( redactionEvent, redactedEvent *types.Event err error @@ -949,7 +1002,7 @@ func (d *EventDatabase) MaybeRedactEvent( ) wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + isRedactionEvent := event.Type() == spec.MRoomRedaction && event.StateKey() == nil if isRedactionEvent { // an event which redacts itself should be ignored if event.EventID() == event.Redacts() { @@ -979,8 +1032,21 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) - _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + var validRoomID *spec.RoomID + validRoomID, err = spec.NewRoomID(redactedEvent.RoomID()) + if err != nil { + return err + } + sender1Domain := "" + sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID()) + if err1 == nil { + sender1Domain = string(sender1.Domain()) + } + sender2Domain := "" + sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) + if err2 == nil { + sender2Domain = string(sender2.Domain()) + } var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -988,9 +1054,9 @@ func (d *EventDatabase) MaybeRedactEvent( } switch { - case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: + case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1 == sender2: + case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true @@ -1034,30 +1100,30 @@ func (d *EventDatabase) MaybeRedactEvent( if ignoreRedaction || redactionEvent == nil || redactedEvent == nil { return nil, nil, nil } - return redactionEvent.Event, redactedEvent.Event, nil + return redactionEvent.PDU, redactedEvent.PDU, nil } // loadRedactionPair returns both the redaction event and the redacted event, else nil. func (d *EventDatabase) loadRedactionPair( - ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo var err error - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + isRedactionEvent := event.Type() == spec.MRoomRedaction && event.StateKey() == nil var eventBeingRedacted string if isRedactionEvent { eventBeingRedacted = event.Redacts() redactionEvent = &types.Event{ EventNID: eventNID, - Event: event, + PDU: event, } } else { eventBeingRedacted = event.EventID() // maybe, we'll see if we have info redactedEvent = &types.Event{ EventNID: eventNID, - Event: event, + PDU: event, } } @@ -1097,7 +1163,10 @@ func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, if len(nids) == 0 { return nil } - evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID}) + if roomInfo == nil { + return nil + } + evs, err := d.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[eventID].EventNID}) if err != nil { return nil } @@ -1107,7 +1176,7 @@ func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, return &evs[0] } -func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) { +func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]gomatrixserverlib.PDU, error) { eventStates, err := d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, []string{eventID}) if err != nil { return nil, err @@ -1124,13 +1193,17 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type if err != nil { eventIDs = map[types.EventNID]string{} } - events := make([]*gomatrixserverlib.Event, 0, len(eventNIDs)) + verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) + if err != nil { + return nil, err + } + events := make([]gomatrixserverlib.PDU, 0, len(eventNIDs)) for _, eventNID := range eventNIDs { data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{eventNID}) if err != nil { return nil, err } - ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[eventNID], data[0].EventJSON, false, roomInfo.RoomVersion) + ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[eventNID], data[0].EventJSON, false) if err != nil { return nil, err } @@ -1142,7 +1215,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type // GetStateEvent returns the current state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error -func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) { roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err @@ -1154,7 +1227,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if roomInfo.IsStub() { return nil, nil } - eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, evType) if err == sql.ErrNoRows { // No rooms have an event of this type, otherwise we'd have an event type NID return nil, nil @@ -1162,7 +1235,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if err != nil { return nil, err } - stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + stateKeyNID, err := d.GetOrCreateEventStateKeyNID(ctx, &stateKey) if err == sql.ErrNoRows { // No rooms have a state event with this state key, otherwise we'd have an state key NID return nil, nil @@ -1180,6 +1253,10 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s eventNIDs = append(eventNIDs, e.EventNID) } } + verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) + if err != nil { + return nil, err + } eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} @@ -1187,6 +1264,10 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s // return the event requested for _, e := range entries { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { + cachedEvent, ok := d.Cache.GetRoomServerEvent(e.EventNID) + if ok { + return &types.HeaderedEvent{PDU: cachedEvent}, nil + } data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID}) if err != nil { return nil, err @@ -1194,11 +1275,11 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if len(data) == 0 { return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID) } - ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[e.EventNID], data[0].EventJSON, false, roomInfo.RoomVersion) + ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[e.EventNID], data[0].EventJSON, false) if err != nil { return nil, err } - return ev.Headered(roomInfo.RoomVersion), nil + return &types.HeaderedEvent{PDU: ev}, nil } } @@ -1207,7 +1288,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s // Same as GetStateEvent but returns all matching state events with this event type. Returns no error // if there are no events with this event type. -func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error) { roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err @@ -1249,13 +1330,17 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy if len(eventPairs) == 0 { return nil, nil } - var result []*gomatrixserverlib.HeaderedEvent + verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) + if err != nil { + return nil, err + } + var result []*types.HeaderedEvent for _, pair := range eventPairs { - ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false, roomInfo.RoomVersion) + ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[pair.EventNID], pair.EventJSON, false) if err != nil { return nil, err } - result = append(result, ev.Headered(roomInfo.RoomVersion)) + result = append(result, &types.HeaderedEvent{PDU: ev}) } return result, nil @@ -1306,7 +1391,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // isn't a failure. - eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) + eventTypeNIDMap, err := d.eventTypeNIDs(ctx, nil, eventTypes) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) } @@ -1371,7 +1456,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu result := make([]tables.StrippedEvent, len(events)) for i := range events { roomVer := eventNIDToVer[events[i].EventNID] - ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false, roomVer) + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVer) + if err != nil { + return nil, err + } + ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event NID %v : %w", events[i].EventNID, err) } @@ -1379,7 +1468,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu EventType: ev.Type(), RoomID: ev.RoomID(), StateKey: *ev.StateKey(), - ContentValue: tables.ExtractContentValue(ev.Headered(roomVer)), + ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}), } } @@ -1469,7 +1558,7 @@ func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomN } // GetServerInRoom returns true if we think a server is in a given room or false otherwise. -func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) { return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName) } @@ -1557,6 +1646,173 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS }) } +// InsertUserRoomPrivatePublicKey inserts a new user room key for the given user and room. +// Returns the newly inserted private key or an existing private key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPrivatePublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + return iErr + }) + return result, err +} + +// InsertUserRoomPublicKey inserts a new user room key for the given user and room. +// Returns the newly inserted public key or an existing public key. If there is +// an error talking to the database, returns that error. +func (d *Database) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + var iErr error + result, iErr = d.UserRoomKeyTable.InsertUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID, key) + return iErr + }) + return result, err +} + +// SelectUserRoomPrivateKey queries the users room private key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +// TODO: Cache this? +func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return eventutil.ErrRoomNoExists{} + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + +// SelectUserRoomPublicKey queries the users room public key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + +// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID +func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { + result = make(map[spec.RoomID]map[string]string, len(publicKeys)) + + // map all roomIDs to roomNIDs + query := make(map[types.RoomNID][]ed25519.PublicKey) + rooms := make(map[types.RoomNID]spec.RoomID) + for roomID, keys := range publicKeys { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) + if !ok { + roomInfo, rErr := d.roomInfo(ctx, nil, roomID.String()) + if rErr != nil { + return nil, rErr + } + if roomInfo == nil { + logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) + continue + } + roomNID = roomInfo.RoomNID + } + + query[roomNID] = keys + rooms[roomNID] = roomID + } + + // get the user room key pars + userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, nil, query) + if sErr != nil { + return nil, sErr + } + nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) + for _, nid := range userRoomKeyPairMap { + nids = append(nids, nid.EventStateKeyNID) + } + // get the userIDs + nidMap, seErr := d.EventStateKeys(ctx, nids) + if seErr != nil { + return nil, seErr + } + + // build the result map (roomID -> map publicKey -> userID) + for publicKey, userRoomKeyPair := range userRoomKeyPairMap { + userID := nidMap[userRoomKeyPair.EventStateKeyNID] + roomID := rooms[userRoomKeyPair.RoomNID] + resMap, exists := result[roomID] + if !exists { + resMap = map[string]string{} + } + resMap[publicKey] = userID + result[roomID] = resMap + } + return result, err +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 941e848021..612e4ef069 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -2,11 +2,16 @@ package shared_test import ( "context" + "crypto/ed25519" "testing" "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" @@ -23,41 +28,62 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat connStr, clearDB := test.PrepareDBConnectionString(t, dbType) dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} - db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + writer := sqlutil.NewExclusiveWriter() + db, err := sqlutil.Open(dbOpts, writer) assert.NoError(t, err) var membershipTable tables.Membership var stateKeyTable tables.EventStateKeys + var userRoomKeys tables.UserRoomKeys + var roomsTable tables.Rooms switch dbType { case test.DBTypePostgres: + err = postgres.CreateRoomsTable(db) + assert.NoError(t, err) err = postgres.CreateEventStateKeysTable(db) assert.NoError(t, err) err = postgres.CreateMembershipTable(db) assert.NoError(t, err) + err = postgres.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + roomsTable, err = postgres.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = postgres.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + assert.NoError(t, err) + userRoomKeys, err = postgres.PrepareUserRoomKeysTable(db) case test.DBTypeSQLite: + err = sqlite3.CreateRoomsTable(db) + assert.NoError(t, err) err = sqlite3.CreateEventStateKeysTable(db) assert.NoError(t, err) err = sqlite3.CreateMembershipTable(db) assert.NoError(t, err) + err = sqlite3.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + roomsTable, err = sqlite3.PrepareRoomsTable(db) + assert.NoError(t, err) membershipTable, err = sqlite3.PrepareMembershipTable(db) assert.NoError(t, err) stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + assert.NoError(t, err) + userRoomKeys, err = sqlite3.PrepareUserRoomKeysTable(db) } assert.NoError(t, err) cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) - evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache, Writer: writer} return &shared.Database{ - DB: db, - EventDatabase: evDb, - MembershipTable: membershipTable, - Writer: sqlutil.NewExclusiveWriter(), - Cache: cache, + DB: db, + EventDatabase: evDb, + MembershipTable: membershipTable, + UserRoomKeyTable: userRoomKeys, + RoomsTable: roomsTable, + Writer: writer, + Cache: cache, }, func() { clearDB() err = db.Close() @@ -97,3 +123,106 @@ func Test_GetLeftUsers(t *testing.T) { assert.ElementsMatch(t, expectedUserIDs, leftUsers) }) } + +func TestUserRoomKeys(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + userID, err := spec.NewUserID(alice.ID, true) + assert.NoError(t, err) + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // create a room NID so we can query the room + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, roomID.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + doesNotExist, err := spec.NewRoomID("!doesnotexist:localhost") + assert.NoError(t, err) + _, err = db.RoomsTable.InsertRoomNID(ctx, nil, doesNotExist.String(), gomatrixserverlib.RoomVersionV10) + assert.NoError(t, err) + + _, key, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + + gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // again, this shouldn't result in an error, but return the existing key + _, key2, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) + + // Key doesn't exist, we shouldn't get anything back + gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, gotKey) + pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, pubKey) + + queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ + *roomID: {key.Public().(ed25519.PublicKey)}, + } + + userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs) + assert.NoError(t, err) + wantKeys := map[spec.RoomID]map[string]string{ + *roomID: { + spec.Base64Bytes(key.Public().(ed25519.PublicKey)).Encode(): userID.String(), + }, + } + assert.Equal(t, wantKeys, userIDs) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *doesNotExist, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + + // test invalid room + reallyDoesNotExist, err := spec.NewRoomID("!reallydoesnotexist:localhost") + assert.NoError(t, err) + _, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4) + assert.Error(t, err) + _, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key) + assert.Error(t, err) + }) +} + +func TestAssignRoomNID(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + nid, err := db.AssignRoomNID(ctx, *roomID, room.Version) + assert.NoError(t, err) + assert.Greater(t, nid, types.EventNID(0)) + + _, err = db.AssignRoomNID(ctx, spec.RoomID{}, "notaroomversion") + assert.Error(t, err) + }) +} diff --git a/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go new file mode 100644 index 0000000000..515bccc374 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha.go @@ -0,0 +1,146 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/util" +) + +func UpDropEventReferenceSHA(ctx context.Context, tx *sql.Tx) error { + var count int + err := tx.QueryRowContext(ctx, `SELECT count(*) FROM roomserver_events GROUP BY event_id HAVING count(event_id) > 1`). + Scan(&count) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to query duplicate event ids") + } + if count > 0 { + return fmt.Errorf("unable to drop column, as there are duplicate event ids") + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_events DROP COLUMN reference_sha256;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func UpDropEventReferenceSHAPrevEvents(ctx context.Context, tx *sql.Tx) error { + // rename the table + if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_previous_events RENAME TO _roomserver_previous_events;`); err != nil { + return fmt.Errorf("tx.ExecContext: %w", err) + } + + // create new table + if _, err := tx.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id) + );`); err != nil { + return fmt.Errorf("tx.ExecContext: %w", err) + } + + // figure out if there are duplicates + dupeRows, err := tx.QueryContext(ctx, `SELECT previous_event_id FROM _roomserver_previous_events GROUP BY previous_event_id HAVING count(previous_event_id) > 1`) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeRows, "failed to close rows") + + var prevEvents []string + var prevEventID string + for dupeRows.Next() { + if err = dupeRows.Scan(&prevEventID); err != nil { + return err + } + prevEvents = append(prevEvents, prevEventID) + } + if dupeRows.Err() != nil { + return dupeRows.Err() + } + + // if we found duplicates, check if we can combine them, e.g. they are in the same room + for _, dupeID := range prevEvents { + var dupeNIDsRows *sql.Rows + dupeNIDsRows, err = tx.QueryContext(ctx, `SELECT event_nids FROM _roomserver_previous_events WHERE previous_event_id = $1`, dupeID) + if err != nil { + return fmt.Errorf("failed to query duplicate event ids") + } + defer internal.CloseAndLogIfError(ctx, dupeNIDsRows, "failed to close rows") + var dupeNIDs []int64 + for dupeNIDsRows.Next() { + var nids pq.Int64Array + if err = dupeNIDsRows.Scan(&nids); err != nil { + return err + } + dupeNIDs = append(dupeNIDs, nids...) + } + + if dupeNIDsRows.Err() != nil { + return dupeNIDsRows.Err() + } + // dedupe NIDs + dupeNIDs = dupeNIDs[:util.SortAndUnique(nids(dupeNIDs))] + // now that we have all NIDs, check which room they belong to + var roomCount int + err = tx.QueryRowContext(ctx, `SELECT count(distinct room_nid) FROM roomserver_events WHERE event_nid IN ($1)`, pq.Array(dupeNIDs)).Scan(&roomCount) + if err != nil { + return err + } + // if the events are from different rooms, that's bad and we can't continue + if roomCount > 1 { + return fmt.Errorf("detected events (%v) referenced for different rooms (%v)", dupeNIDs, roomCount) + } + // otherwise delete the dupes + _, err = tx.ExecContext(ctx, "DELETE FROM _roomserver_previous_events WHERE previous_event_id = $1", dupeID) + if err != nil { + return fmt.Errorf("unable to delete duplicates: %w", err) + } + + // insert combined values + _, err = tx.ExecContext(ctx, "INSERT INTO _roomserver_previous_events (previous_event_id, event_nids) VALUES ($1, $2)", dupeID, pq.Array(dupeNIDs)) + if err != nil { + return fmt.Errorf("unable to insert new event NIDs: %w", err) + } + } + + // move data + if _, err = tx.ExecContext(ctx, ` +INSERT + INTO roomserver_previous_events ( + previous_event_id, event_nids + ) SELECT + previous_event_id, event_nids + FROM _roomserver_previous_events +;`); err != nil { + return fmt.Errorf("tx.ExecContext: %w", err) + } + // drop old table + _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_previous_events;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +type nids []int64 + +func (s nids) Len() int { return len(s) } +func (s nids) Less(i, j int) bool { return s[i] < s[j] } +func (s nids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go new file mode 100644 index 0000000000..547d9703be --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20230516154000_drop_reference_sha_test.go @@ -0,0 +1,59 @@ +package deltas + +import ( + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +func TestUpDropEventReferenceSHAPrevEvents(t *testing.T) { + + cfg, ctx, close := testrig.CreateConfig(t, test.DBTypeSQLite) + defer close() + + db, err := sqlutil.Open(&cfg.RoomServer.Database, sqlutil.NewExclusiveWriter()) + assert.Nil(t, err) + assert.NotNil(t, db) + defer db.Close() + + // create the table in the old layout + _, err = db.ExecContext(ctx.Context(), ` + CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + previous_reference_sha256 BLOB, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id, previous_reference_sha256) + );`) + assert.Nil(t, err) + + // create the events table as well, slimmed down with one eventNID + _, err = db.ExecContext(ctx.Context(), ` + CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL +); + +INSERT INTO roomserver_events (event_nid, room_nid) VALUES (1, 1) +`) + assert.Nil(t, err) + + // insert duplicate prev events with different event_nids + stmt, err := db.PrepareContext(ctx.Context(), `INSERT INTO roomserver_previous_events (previous_event_id, event_nids, previous_reference_sha256) VALUES ($1, $2, $3)`) + assert.Nil(t, err) + assert.NotNil(t, stmt) + _, err = stmt.ExecContext(ctx.Context(), "1", "{1,2}", "a") + assert.Nil(t, err) + _, err = stmt.ExecContext(ctx.Context(), "1", "{1,2,3}", "b") + assert.Nil(t, err) + + // execute the migration + txn, err := db.Begin() + assert.Nil(t, err) + assert.NotNil(t, txn) + err = UpDropEventReferenceSHAPrevEvents(ctx.Context(), txn) + defer txn.Rollback() + assert.NoError(t, err) +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index aacf4bc9a8..c49c6dc38a 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -19,14 +19,14 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "sort" "strings" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -41,17 +41,16 @@ const eventsSchema = ` state_snapshot_nid INTEGER NOT NULL DEFAULT 0, depth INTEGER NOT NULL, event_id TEXT NOT NULL UNIQUE, - reference_sha256 BLOB NOT NULL, auth_event_nids TEXT NOT NULL DEFAULT '[]', is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); ` const insertEventSQL = ` - INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, auth_event_nids, depth, is_rejected) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT DO UPDATE - SET is_rejected = $8 WHERE is_rejected = 1 + SET is_rejected = $7 WHERE is_rejected = 1 RETURNING event_nid, state_snapshot_nid; ` @@ -100,12 +99,9 @@ const selectEventIDSQL = "" + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" const bulkSelectStateAtEventAndReferenceSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id" + " FROM roomserver_events WHERE event_nid IN ($1)" -const bulkSelectEventReferenceSQL = "" + - "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" - const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" @@ -137,7 +133,6 @@ type eventStatements struct { updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt selectEventRejectedStmt *sql.Stmt //bulkSelectEventNIDStmt *sql.Stmt @@ -147,7 +142,32 @@ type eventStatements struct { func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) - return err + if err != nil { + return err + } + + // check if the column exists + var cName string + migrationName := "roomserver: drop column reference_sha from roomserver_events" + err = db.QueryRowContext(context.Background(), `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'roomserver_events' AND p.name = 'reference_sha256'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: migrationName, + Up: deltas.UpDropEventReferenceSHA, + }, + }...) + return m.Up(context.Background()) } func PrepareEventsTable(db *sql.DB) (tables.Events, error) { @@ -167,7 +187,6 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, - {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, @@ -183,7 +202,6 @@ func (s *eventStatements) InsertEvent( eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, - referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, @@ -194,7 +212,7 @@ func (s *eventStatements) InsertEvent( insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) err := insertStmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, + eventID, eventNIDsAsArray(authEventNIDs), depth, isRejected, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -475,11 +493,10 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( eventNID int64 stateSnapshotNID int64 eventID string - eventSHA256 []byte ) for ; rows.Next(); i++ { if err = rows.Scan( - &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, ); err != nil { return nil, err } @@ -489,43 +506,6 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( result.EventNID = types.EventNID(eventNID) result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) result.EventID = eventID - result.EventSHA256 = eventSHA256 - } - if i != len(eventNIDs) { - return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) - } - return results, nil -} - -func (s *eventStatements) BulkSelectEventReference( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.EventReference, error) { - /////////////// - iEventNIDs := make([]interface{}, len(eventNIDs)) - for k, v := range eventNIDs { - iEventNIDs[k] = v - } - selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - defer selectPrep.Close() // nolint:errcheck - /////////////// - - selectStmt := sqlutil.TxStmt(txn, selectPrep) - rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventReference: rows.close() failed") - results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) - i := 0 - for ; rows.Next(); i++ { - result := &results[i] - if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { - return nil, err - } } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 8a60b359f4..977788d505 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -21,13 +21,12 @@ import ( "fmt" "strings" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const membershipSchema = ` @@ -398,7 +397,7 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) { var nid types.RoomNID stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 2a146ef64b..4e59fbba7d 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -18,10 +18,12 @@ package sqlite3 import ( "context" "database/sql" + "errors" "fmt" "strings" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -34,9 +36,8 @@ import ( const previousEventSchema = ` CREATE TABLE IF NOT EXISTS roomserver_previous_events ( previous_event_id TEXT NOT NULL, - previous_reference_sha256 BLOB, event_nids TEXT NOT NULL, - UNIQUE (previous_event_id, previous_reference_sha256) + UNIQUE (previous_event_id) ); ` @@ -47,20 +48,20 @@ const previousEventSchema = ` // The lock is necessary to avoid data races when checking whether an event is already referenced by another event. const insertPreviousEventSQL = ` INSERT OR REPLACE INTO roomserver_previous_events - (previous_event_id, previous_reference_sha256, event_nids) - VALUES ($1, $2, $3) + (previous_event_id, event_nids) + VALUES ($1, $2) ` const selectPreviousEventNIDsSQL = ` SELECT event_nids FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = ` SELECT 1 FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` type previousEventStatements struct { @@ -72,7 +73,30 @@ type previousEventStatements struct { func CreatePrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) - return err + if err != nil { + return err + } + // check if the column exists + var cName string + migrationName := "roomserver: drop column reference_sha from roomserver_prev_events" + err = db.QueryRowContext(context.Background(), `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'roomserver_previous_events' AND p.name = 'previous_reference_sha256'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations([]sqlutil.Migration{ + { + Version: migrationName, + Up: deltas.UpDropEventReferenceSHAPrevEvents, + }, + }...) + return m.Up(context.Background()) } func PreparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { @@ -91,13 +115,12 @@ func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, - previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { var eventNIDs string eventNIDAsString := fmt.Sprintf("%d", eventNID) selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) + err := selectStmt.QueryRowContext(ctx, previousEventID).Scan(&eventNIDs) if err != nil && err != sql.ErrNoRows { return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) } @@ -115,7 +138,7 @@ func (s *previousEventStatements) InsertPreviousEvent( } insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err = insertStmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, + ctx, previousEventID, eventNIDs, ) return err } @@ -123,9 +146,9 @@ func (s *previousEventStatements) InsertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) SelectPreviousEventExists( - ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, + ctx context.Context, txn *sql.Tx, eventID string, ) error { var ok int64 stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + return stmt.QueryRowContext(ctx, eventID).Scan(&ok) } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index e57e1a4bfc..2edff0ba8a 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -26,7 +26,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -153,7 +152,7 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( return nil, tables.OptimisationNotSupportedError } -func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) { +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*types.HeaderedEvent, error) { return nil, tables.OptimisationNotSupportedError } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 89e16fc141..ef51a5b08c 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -21,14 +21,13 @@ import ( "errors" "fmt" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) // A Database is used to store room events and stream offsets. @@ -139,6 +138,9 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRedactionsTable(db); err != nil { return err } + if err := CreateUserRoomKeysTable(db); err != nil { + return err + } return nil } @@ -200,6 +202,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + userRoomKeys, err := PrepareUserRoomKeysTable(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, @@ -225,6 +231,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, GetRoomUpdaterFn: d.GetRoomUpdater, Purge: purge, + UserRoomKeyTable: userRoomKeys, } return nil } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go new file mode 100644 index 0000000000..5d6ddc9a8e --- /dev/null +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -0,0 +1,167 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "crypto/ed25519" + "database/sql" + "errors" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +const userRoomKeysSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( + user_nid INTEGER NOT NULL, + room_nid INTEGER NOT NULL, + pseudo_id_key TEXT NULL, -- may be null for users not local to the server + pseudo_id_pub_key TEXT NOT NULL, + CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) +); +` + +const insertUserRoomKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) + ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key + RETURNING (pseudo_id_key) +` + +const insertUserRoomPublicKeySQL = ` + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_pub_key) VALUES ($1, $2, $3) + ON CONFLICT DO UPDATE SET pseudo_id_pub_key = $3 + RETURNING (pseudo_id_pub_key) +` + +const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` + +type userRoomKeysStatements struct { + db *sql.DB + insertUserRoomPrivateKeyStmt *sql.Stmt + insertUserRoomPublicKeyStmt *sql.Stmt + selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt + //selectUserNIDsStmt *sql.Stmt //prepared at runtime +} + +func CreateUserRoomKeysTable(db *sql.DB) error { + _, err := db.Exec(userRoomKeysSchema) + return err +} + +func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { + s := &userRoomKeysStatements{db: db} + return s, sqlutil.StatementList{ + {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, + {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, + {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, + //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime + }.Prepare(db) +} + +func (s *userRoomKeysStatements) InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPrivateKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomPublicKeyStmt) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeyStmt) + var result ed25519.PrivateKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + + roomNIDs := make([]any, 0, len(senderKeys)) + var senders []any + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + + for _, key := range senderKeys[roomNID] { + senders = append(senders, []byte(key)) + } + } + + selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1) + selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs + + selectStmt, err := s.db.Prepare(selectSQL) + if err != nil { + return nil, err + } + + params := append(roomNIDs, senders...) + + stmt := sqlutil.TxStmt(txn, selectStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.UserRoomKeyPair, len(params)) + var publicKey []byte + userRoomKeyPair := types.UserRoomKeyPair{} + for rows.Next() { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { + return nil, err + } + result[spec.Base64Bytes(publicKey).Encode()] = userRoomKeyPair + } + return result, rows.Err() +} diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index 107af47845..5ed805648d 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -11,7 +11,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) @@ -48,10 +47,9 @@ func Test_EventsTable(t *testing.T) { // create some dummy data eventIDs := make([]string, 0, len(room.Events())) wantStateAtEvent := make([]types.StateAtEvent, 0, len(room.Events())) - wantEventReferences := make([]gomatrixserverlib.EventReference, 0, len(room.Events())) wantStateAtEventAndRefs := make([]types.StateAtEventAndReference, 0, len(room.Events())) for _, ev := range room.Events() { - eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), ev.EventReference().EventSHA256, nil, ev.Depth(), false) + eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), nil, ev.Depth(), false) assert.NoError(t, err) gotEventNID, gotSnapNID, err := tab.SelectEvent(ctx, nil, ev.EventID()) assert.NoError(t, err) @@ -75,7 +73,6 @@ func Test_EventsTable(t *testing.T) { assert.True(t, sentToOutput) eventIDs = append(eventIDs, ev.EventID()) - wantEventReferences = append(wantEventReferences, ev.EventReference()) // Set the stateSnapshot to 2 for some events to verify they are returned later stateSnapshot := 0 @@ -97,8 +94,8 @@ func Test_EventsTable(t *testing.T) { } wantStateAtEvent = append(wantStateAtEvent, stateAtEvent) wantStateAtEventAndRefs = append(wantStateAtEventAndRefs, types.StateAtEventAndReference{ - StateAtEvent: stateAtEvent, - EventReference: ev.EventReference(), + StateAtEvent: stateAtEvent, + EventID: ev.EventID(), }) } @@ -140,10 +137,6 @@ func Test_EventsTable(t *testing.T) { assert.True(t, ok) } - references, err := tab.BulkSelectEventReference(ctx, nil, nids) - assert.NoError(t, err) - assert.Equal(t, wantEventReferences, references) - stateAndRefs, err := tab.BulkSelectStateAtEventAndReference(ctx, nil, nids) assert.NoError(t, err) assert.Equal(t, wantStateAtEventAndRefs, stateAndRefs) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 4ce2a9c4e4..445c1223fa 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -2,10 +2,12 @@ package tables import ( "context" + "crypto/ed25519" "database/sql" "errors" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,7 +43,7 @@ type Events interface { InsertEvent( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, - referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, + authEventNIDs []types.EventNID, depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[types.StateSnapshotNID][]string, error) @@ -58,7 +60,6 @@ type Events interface { UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) - BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) // BulkSelectEventID returns a map from numeric event ID to string event ID. BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. @@ -94,7 +95,7 @@ type StateSnapshot interface { BulkSelectMembershipForHistoryVisibility( ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, - ) (map[string]*gomatrixserverlib.HeaderedEvent, error) + ) (map[string]*types.HeaderedEvent, error) } type StateBlock interface { @@ -112,10 +113,10 @@ type RoomAliases interface { } type PreviousEvents interface { - InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error + InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, eventNID types.EventNID) error // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. - SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error + SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string) error } type Invites interface { @@ -147,7 +148,7 @@ type Membership interface { SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) - SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } @@ -184,6 +185,21 @@ type Purge interface { ) error } +type UserRoomKeys interface { + // InsertUserRoomPrivatePublicKey inserts the given private key as well as the public key for it. This should be used + // when creating keys locally. + InsertUserRoomPrivatePublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) + // InsertUserRoomPublicKey inserts the given public key, this should be used for users NOT local to this server + InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) + // SelectUserRoomPrivateKey selects the private key for the given user and room combination + SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error) + // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. + // If a senderKey can't be found, it is omitted in the result. + BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string @@ -195,21 +211,21 @@ type StrippedEvent struct { // ExtractContentValue from the given state event. For example, given an m.room.name event with: // content: { name: "Foo" } // this returns "Foo". -func ExtractContentValue(ev *gomatrixserverlib.HeaderedEvent) string { +func ExtractContentValue(ev *types.HeaderedEvent) string { content := ev.Content() key := "" switch ev.Type() { - case gomatrixserverlib.MRoomCreate: + case spec.MRoomCreate: key = "creator" - case gomatrixserverlib.MRoomCanonicalAlias: + case spec.MRoomCanonicalAlias: key = "alias" - case gomatrixserverlib.MRoomHistoryVisibility: + case spec.MRoomHistoryVisibility: key = "history_visibility" - case gomatrixserverlib.MRoomJoinRules: + case spec.MRoomJoinRules: key = "join_rule" - case gomatrixserverlib.MRoomMember: + case spec.MRoomMember: key = "membership" - case gomatrixserverlib.MRoomName: + case spec.MRoomName: key = "name" case "m.room.avatar": key = "url" diff --git a/roomserver/storage/tables/previous_events_table_test.go b/roomserver/storage/tables/previous_events_table_test.go index 63d540696a..9d41e90be5 100644 --- a/roomserver/storage/tables/previous_events_table_test.go +++ b/roomserver/storage/tables/previous_events_table_test.go @@ -45,17 +45,17 @@ func TestPreviousEventsTable(t *testing.T) { defer close() for _, x := range room.Events() { - for _, prevEvent := range x.PrevEvents() { - err := tab.InsertPreviousEvent(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256, 1) + for _, eventID := range x.PrevEventIDs() { + err := tab.InsertPreviousEvent(ctx, nil, eventID, 1) assert.NoError(t, err) - err = tab.SelectPreviousEventExists(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256) + err = tab.SelectPreviousEventExists(ctx, nil, eventID) assert.NoError(t, err) } } - // RandomString with a correct EventSHA256 should fail and return sql.ErrNoRows - err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16), room.Events()[0].EventReference().EventSHA256) + // RandomString should fail and return sql.ErrNoRows + err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16)) assert.Error(t, err) }) } diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go new file mode 100644 index 0000000000..2809771b4c --- /dev/null +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -0,0 +1,123 @@ +package tables_test + +import ( + "context" + "crypto/ed25519" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/stretchr/testify/assert" + ed255192 "golang.org/x/crypto/ed25519" +) + +func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareUserRoomKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateUserRoomKeysTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareUserRoomKeysTable(db) + } + assert.NoError(t, err) + + return tab, db, close +} + +func TestUserRoomKeysTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := mustCreateUserRoomKeysTable(t, dbType) + defer close() + userNID := types.EventStateKeyNID(1) + roomNID := types.RoomNID(1) + _, key, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + + err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var gotKey, key2, key3 ed25519.PrivateKey + var pubKey ed25519.PublicKey + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // again, this shouldn't result in an error, but return the existing key + _, key2, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) + + // add another user + _, key3, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + userNID2 := types.EventStateKeyNID(2) + _, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID2, roomNID, key3) + assert.NoError(t, err) + + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) + + // try to update an existing key, this should only be done for users NOT on this homeserver + var gotPubKey ed25519.PublicKey + gotPubKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, roomNID, key2.Public().(ed25519.PublicKey)) + assert.NoError(t, err) + assert.Equal(t, key2.Public(), gotPubKey) + + // Key doesn't exist + gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, pubKey) + + // query user NIDs for senderKeys + var gotKeys map[string]types.UserRoomKeyPair + query := map[types.RoomNID][]ed25519.PublicKey{ + roomNID: {key2.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, + types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist + } + gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query) + assert.NoError(t, err) + assert.NotNil(t, gotKeys) + + wantKeys := map[string]types.UserRoomKeyPair{ + string(spec.Base64Bytes(key2.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(spec.Base64Bytes(key3.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID2}, + } + assert.Equal(t, wantKeys, gotKeys) + + // insert key that came in over federation + var gotPublicKey, key4 ed255192.PublicKey + key4, _, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotPublicKey, err = tab.InsertUserRoomPublicKey(context.Background(), txn, userNID, 2, key4) + assert.NoError(t, err) + assert.Equal(t, key4, gotPublicKey) + + return nil + }) + assert.NoError(t, err) + + }) +} diff --git a/roomserver/types/headered_event.go b/roomserver/types/headered_event.go new file mode 100644 index 0000000000..7839998222 --- /dev/null +++ b/roomserver/types/headered_event.go @@ -0,0 +1,62 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "unsafe" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" +) + +// HeaderedEvent is an Event which serialises to the headered form, which includes +// _room_version and _event_id fields. +type HeaderedEvent struct { + gomatrixserverlib.PDU + Visibility gomatrixserverlib.HistoryVisibility + // TODO: Remove this. This is a temporary workaround to store the userID in the syncAPI. + // It really should be the userKey instead. + UserID spec.UserID + StateKeyResolved *string +} + +func (h *HeaderedEvent) CacheCost() int { + return int(unsafe.Sizeof(*h)) + + len(h.EventID()) + + (cap(h.JSON()) * 2) + + len(h.Version()) + + 1 // redacted bool +} + +func (h *HeaderedEvent) MarshalJSON() ([]byte, error) { + return h.PDU.ToHeaderedJSON() +} + +func (j *HeaderedEvent) UnmarshalJSON(data []byte) error { + ev, err := gomatrixserverlib.NewEventFromHeaderedJSON(data, false) + if err != nil { + return err + } + j.PDU = ev + return nil +} + +func NewEventJSONsFromHeaderedEvents(hes []*HeaderedEvent) gomatrixserverlib.EventJSONs { + result := make(gomatrixserverlib.EventJSONs, len(hes)) + for i := range hes { + result[i] = hes[i].JSON() + } + return result +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 6401a94bee..45a3e25fcc 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -17,6 +17,7 @@ package types import ( "encoding/json" + "fmt" "sort" "strings" "sync" @@ -43,6 +44,11 @@ type EventMetadata struct { RoomNID RoomNID } +type UserRoomKeyPair struct { + RoomNID RoomNID + EventStateKeyNID EventStateKeyNID +} + // StateSnapshotNID is a numeric ID for the state at an event. type StateSnapshotNID int64 @@ -199,7 +205,7 @@ func (s StateAtEvent) IsStateEvent() bool { // The StateAtEvent is used to construct the current state of the room from the latest events. type StateAtEventAndReference struct { StateAtEvent - gomatrixserverlib.EventReference + EventID string } type StateAtEventAndReferences []StateAtEventAndReference @@ -228,7 +234,7 @@ func (s StateAtEventAndReferences) EventIDs() string { // It is when performing bulk event lookup in the database. type Event struct { EventNID EventNID - *gomatrixserverlib.Event + gomatrixserverlib.PDU } const ( @@ -328,3 +334,5 @@ func (r *RoomInfo) CopyFrom(r2 *RoomInfo) { r.stateSnapshotNID = r2.stateSnapshotNID r.isStub = r2.isStub } + +var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid") diff --git a/roomserver/version/version.go b/roomserver/version/version.go index c40d8e0f77..270d428972 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -28,39 +28,32 @@ func DefaultRoomVersion() gomatrixserverlib.RoomVersion { // RoomVersions returns a map of all known room versions to this // server. -func RoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.RoomVersionDescription { +func RoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.IRoomVersion { return gomatrixserverlib.RoomVersions() } // SupportedRoomVersions returns a map of descriptions for room // versions that are supported by this homeserver. -func SupportedRoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.RoomVersionDescription { - return gomatrixserverlib.SupportedRoomVersions() +func SupportedRoomVersions() map[gomatrixserverlib.RoomVersion]gomatrixserverlib.IRoomVersion { + return gomatrixserverlib.RoomVersions() } // RoomVersion returns information about a specific room version. // An UnknownVersionError is returned if the version is not known // to the server. -func RoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.RoomVersionDescription, error) { +func RoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.IRoomVersion, error) { if version, ok := gomatrixserverlib.RoomVersions()[version]; ok { return version, nil } - return gomatrixserverlib.RoomVersionDescription{}, UnknownVersionError{version} + return nil, UnknownVersionError{version} } // SupportedRoomVersion returns information about a specific room // version. An UnknownVersionError is returned if the version is not // known to the server, or an UnsupportedVersionError is returned if // the version is known but specifically marked as unsupported. -func SupportedRoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.RoomVersionDescription, error) { - result, err := RoomVersion(version) - if err != nil { - return gomatrixserverlib.RoomVersionDescription{}, err - } - if !result.Supported { - return gomatrixserverlib.RoomVersionDescription{}, UnsupportedVersionError{version} - } - return result, nil +func SupportedRoomVersion(version gomatrixserverlib.RoomVersion) (gomatrixserverlib.IRoomVersion, error) { + return RoomVersion(version) } // UnknownVersionError is caused when the room version is not known. diff --git a/setup/base/base.go b/setup/base/base.go index d6c3501098..ea342054cb 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -74,7 +74,7 @@ func CreateClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) *fclient.Cli // CreateFederationClient creates a new federation client. Should only be called // once per component. -func CreateFederationClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) *fclient.FederationClient { +func CreateFederationClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) fclient.FederationClient { identities := cfg.Global.SigningIdentities() if cfg.Global.DisableFederation { return fclient.NewFederationClient( @@ -85,6 +85,7 @@ func CreateFederationClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) *f fclient.WithTimeout(time.Minute * 5), fclient.WithSkipVerify(cfg.FederationAPI.DisableTLSValidation), fclient.WithKeepAlives(!cfg.FederationAPI.DisableHTTPKeepalives), + fclient.WithUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())), } if cfg.Global.DNSCache.Enabled { opts = append(opts, fclient.WithDNSCache(dnsCache)) @@ -92,7 +93,6 @@ func CreateFederationClient(cfg *config.Dendrite, dnsCache *fclient.DNSCache) *f client := fclient.NewFederationClient( identities, opts..., ) - client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client } diff --git a/setup/config/config.go b/setup/config/config.go index 67106fb1ce..1aa674f035 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" "gopkg.in/yaml.v2" @@ -239,7 +240,7 @@ func loadConfig( key.KeyID = keyID key.PrivateKey = privateKey - key.PublicKey = gomatrixserverlib.Base64Bytes(privateKey.Public().(ed25519.PublicKey)) + key.PublicKey = spec.Base64Bytes(privateKey.Public().(ed25519.PublicKey)) case key.KeyID == "": return nil, fmt.Errorf("'key_id' must be specified if 'public_key' is specified") diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 34c225a56a..0a602cea68 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -16,6 +16,10 @@ type ClientAPI struct { // secrets) RegistrationDisabled bool `yaml:"registration_disabled"` + // If set, requires users to submit a token during registration. + // Tokens can be managed using admin API. + RegistrationRequiresToken bool `yaml:"registration_requires_token"` + // Enable registration without captcha verification or shared secret. // This option is populated by the -really-enable-open-registration // command line parameter as it is not recommended. @@ -90,6 +94,7 @@ type Ldap struct { func (c *ClientAPI) Defaults(_ DefaultOpts) { c.RegistrationSharedSecret = "" + c.RegistrationRequiresToken = false c.RecaptchaPublicKey = "" c.RecaptchaPrivateKey = "" c.RecaptchaEnabled = false diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 8c1540b579..a72eee369a 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -2,6 +2,7 @@ package config import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type FederationAPI struct { @@ -101,7 +102,7 @@ type KeyPerspectives []KeyPerspective type KeyPerspective struct { // The server name of the perspective key server - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + ServerName spec.ServerName `yaml:"server_name"` // Server keys for the perspective user, used to verify the // keys have been signed by the perspective server Keys []KeyPerspectiveTrustKey `yaml:"keys"` diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 0687e9d351..1622bf3576 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -9,6 +9,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/ed25519" ) @@ -122,7 +123,7 @@ func (c *Global) Verify(configErrs *ConfigErrors) { c.Cache.Verify(configErrs) } -func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { +func (c *Global) IsLocalServerName(serverName spec.ServerName) bool { if c.ServerName == serverName { return true } @@ -134,7 +135,7 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool return false } -func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) { +func (c *Global) SplitLocalID(sigil byte, id string) (string, spec.ServerName, error) { u, s, err := gomatrixserverlib.SplitID(sigil, id) if err != nil { return u, s, err @@ -145,7 +146,7 @@ func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib. return u, s, nil } -func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHost { +func (c *Global) VirtualHost(serverName spec.ServerName) *VirtualHost { for _, v := range c.VirtualHosts { if v.ServerName == serverName { return v @@ -154,7 +155,7 @@ func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHo return nil } -func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) *VirtualHost { +func (c *Global) VirtualHostForHTTPHost(serverName spec.ServerName) *VirtualHost { for _, v := range c.VirtualHosts { if v.ServerName == serverName { return v @@ -168,7 +169,7 @@ func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) return nil } -func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*fclient.SigningIdentity, error) { +func (c *Global) SigningIdentityFor(serverName spec.ServerName) (*fclient.SigningIdentity, error) { for _, id := range c.SigningIdentities() { if id.ServerName == serverName { return id, nil @@ -205,7 +206,7 @@ type VirtualHost struct { // Match these HTTP Host headers on the `/key/v2/server` endpoint, this needs // to match all delegated names, likely including the port number too if // the well-known delegation includes that also. - MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` + MatchHTTPHosts []spec.ServerName `yaml:"match_http_hosts"` // Is registration enabled on this virtual host? AllowRegistration bool `yaml:"allow_registration"` @@ -236,14 +237,14 @@ type OldVerifyKeys struct { PrivateKey ed25519.PrivateKey `yaml:"-"` // The public key, in case only that part is known. - PublicKey gomatrixserverlib.Base64Bytes `yaml:"public_key"` + PublicKey spec.Base64Bytes `yaml:"public_key"` // The key ID of the private key. KeyID gomatrixserverlib.KeyID `yaml:"key_id"` // When the private key was designed as "expired", as a UNIX timestamp // in millisecond precision. - ExpiredAt gomatrixserverlib.Timestamp `yaml:"expired_at"` + ExpiredAt spec.Timestamp `yaml:"expired_at"` } // The configuration to use for Prometheus metrics diff --git a/setup/config/config_test.go b/setup/config/config_test.go index a0509aafbf..8a65c990f8 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -19,8 +19,8 @@ import ( "reflect" "testing" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) @@ -275,7 +275,7 @@ func Test_SigningIdentityFor(t *testing.T) { tests := []struct { name string virtualHosts []*VirtualHost - serverName gomatrixserverlib.ServerName + serverName spec.ServerName want *fclient.SigningIdentity wantErr bool }{ @@ -285,17 +285,17 @@ func Test_SigningIdentityFor(t *testing.T) { }, { name: "no identity found", - serverName: gomatrixserverlib.ServerName("doesnotexist"), + serverName: spec.ServerName("doesnotexist"), wantErr: true, }, { name: "found identity", - serverName: gomatrixserverlib.ServerName("main"), + serverName: spec.ServerName("main"), want: &fclient.SigningIdentity{ServerName: "main"}, }, { name: "identity found on virtual hosts", - serverName: gomatrixserverlib.ServerName("vh2"), + serverName: spec.ServerName("vh2"), virtualHosts: []*VirtualHost{ {SigningIdentity: fclient.SigningIdentity{ServerName: "vh1"}}, {SigningIdentity: fclient.SigningIdentity{ServerName: "vh2"}}, diff --git a/setup/monolith.go b/setup/monolith.go index 5f06529070..5296b551e2 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -41,7 +41,7 @@ type Monolith struct { Config *config.Dendrite KeyRing *gomatrixserverlib.KeyRing Client *fclient.Client - FedClient *fclient.FederationClient + FedClient fclient.FederationClient AppserviceAPI appserviceAPI.AppServiceInternalAPI FederationAPI federationAPI.FederationInternalAPI @@ -73,7 +73,7 @@ func (m *Monolith) AddAllPublicRoutes( m.ExtPublicRoomsProvider, enableMetrics, ) federationapi.AddPublicRoutes( - processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, enableMetrics, + processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics, ) mediaapi.AddPublicRoutes(routers.Media, cm, cfg, m.UserAPI, m.Client) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index e1758920b3..f284199050 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -27,17 +27,18 @@ import ( "strings" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -87,13 +88,15 @@ type EventRelationshipResponse struct { type MSC2836EventRelationshipsResponse struct { fclient.MSC2836EventRelationshipsResponse - ParsedEvents []*gomatrixserverlib.Event - ParsedAuthChain []*gomatrixserverlib.Event + ParsedEvents []gomatrixserverlib.PDU + ParsedAuthChain []gomatrixserverlib.PDU } -func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { +func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(res.ParsedEvents, synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -111,7 +114,7 @@ func Enable( } hooks.Enable() hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { - he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + he := headeredEvent.(*types.HeaderedEvent) hookErr := db.StoreRelation(context.Background(), he) if hookErr != nil { util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error( @@ -134,7 +137,7 @@ func Enable( routers.Federation.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + fedReq, errResp := fclient.VerifyHTTPRequest( req, time.Now(), cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { @@ -151,12 +154,12 @@ type reqCtx struct { rsAPI roomserver.RoomserverInternalAPI db Database req *EventRelationshipRequest - userID string + userID spec.UserID roomVersion gomatrixserverlib.RoomVersion // federated request args isFederatedRequest bool - serverName gomatrixserverlib.ServerName + serverName spec.ServerName fsAPI fs.FederationInternalAPI } @@ -167,13 +170,20 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON") return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } rc := reqCtx{ ctx: req.Context(), req: relation, - userID: device.UserID, + userID: *userID, rsAPI: rsAPI, fsAPI: fsAPI, isFederatedRequest: false, @@ -186,20 +196,20 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP return util.JSONResponse{ Code: 200, - JSON: toClientResponse(res), + JSON: toClientResponse(req.Context(), res, rsAPI), } } } func federatedEventRelationship( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, + ctx context.Context, fedReq *fclient.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, ) util.JSONResponse { relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content())) if err != nil { util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON") return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)), } } rc := reqCtx{ @@ -254,7 +264,7 @@ func federatedEventRelationship( func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONResponse) { var res MSC2836EventRelationshipsResponse - var returnEvents []*gomatrixserverlib.HeaderedEvent + var returnEvents []*types.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID) if event == nil { @@ -266,7 +276,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo if event == nil || !rc.authorisedToSeeEvent(event) { return nil, &util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), + JSON: spec.Forbidden("Event does not exist or you are not authorised to see it"), } } rc.roomVersion = event.Version() @@ -298,17 +308,17 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo for _, ev := range returnEvents { included[ev.EventID()] = true } - var events []*gomatrixserverlib.HeaderedEvent + var events []*types.HeaderedEvent events, walkLimited = walkThread( rc.ctx, rc.db, rc, included, remaining, ) returnEvents = append(returnEvents, events...) } - res.ParsedEvents = make([]*gomatrixserverlib.Event, len(returnEvents)) + res.ParsedEvents = make([]gomatrixserverlib.PDU, len(returnEvents)) for i, ev := range returnEvents { // for each event, extract the children_count | hash and add it as unsigned data. rc.addChildMetadata(ev) - res.ParsedEvents[i] = ev.Unwrap() + res.ParsedEvents[i] = ev.PDU } res.Limited = remaining == 0 || walkLimited return &res, nil @@ -317,17 +327,15 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo // fetchUnknownEvent retrieves an unknown event from the room specified. This server must // be joined to the room in question. This has the side effect of injecting surround threaded // events into the roomserver. -func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *types.HeaderedEvent { if rc.isFederatedRequest || roomID == "" { // we don't do fed hits for fed requests, and we can't ask servers without a room ID! return nil } logger := util.GetLogger(rc.ctx).WithField("room_id", roomID) // if they supplied a room_id, check the room exists. - var queryVerRes roomserver.QueryRoomVersionForRoomResponse - err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, &roomserver.QueryRoomVersionForRoomRequest{ - RoomID: roomID, - }, &queryVerRes) + + roomVersion, err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, roomID) if err != nil { logger.WithError(err).Warn("failed to query room version for room, does this room exist?") return nil @@ -366,14 +374,14 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.H // Inject the response into the roomserver to remember the event across multiple calls and to set // unexplored flags correctly. for _, srv := range serversToQuery { - res, err := rc.MSC2836EventRelationships(eventID, srv, queryVerRes.RoomVersion) + res, err := rc.MSC2836EventRelationships(eventID, srv, roomVersion) if err != nil { continue } rc.injectResponseToRoomserver(res) for _, ev := range res.ParsedEvents { if ev.EventID() == eventID { - return ev.Headered(ev.Version()) + return &types.HeaderedEvent{PDU: ev} } } } @@ -383,7 +391,7 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.H // If include_parent: true and there is a valid m.relationship field in the event, // retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. -func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { +func (rc *reqCtx) includeParent(childEvent *types.HeaderedEvent) (parent *types.HeaderedEvent) { parentID, _, _ := parentChildEventIDs(childEvent) if parentID == "" { return nil @@ -394,7 +402,7 @@ func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (pa // If include_children: true, lookup all events which have event_id as an m.relationship // Apply history visibility checks to all these events and add the ones which pass into the response array, // honouring the recent_first flag and the limit. -func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { +func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*types.HeaderedEvent, *util.JSONResponse) { if rc.hasUnexploredChildren(parentID) { // we need to do a remote request to pull in the children as we are missing them locally. serversToQuery := rc.getServersForEventID(parentID) @@ -428,10 +436,12 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") - resErr := jsonerror.InternalServerError() - return nil, &resErr + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - var childEvents []*gomatrixserverlib.HeaderedEvent + var childEvents []*types.HeaderedEvent for _, child := range children { childEvent := rc.lookForEvent(child.EventID) if childEvent != nil { @@ -448,8 +458,8 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen // honouring the limit, max_depth and max_breadth values according to the following rules func walkThread( ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, -) ([]*gomatrixserverlib.HeaderedEvent, bool) { - var result []*gomatrixserverlib.HeaderedEvent +) ([]*types.HeaderedEvent, bool) { + var result []*types.HeaderedEvent eventWalker := walker{ ctx: ctx, req: rc.req, @@ -486,7 +496,7 @@ func walkThread( } // MSC2836EventRelationships performs an /event_relationships request to a remote server -func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { +func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv spec.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, fclient.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, @@ -511,7 +521,7 @@ func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverli // authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to // see this request. This only needs to be done once per room at present as we just check for joined status. -func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool { +func (rc *reqCtx) authorisedToSeeEvent(event *types.HeaderedEvent) bool { if rc.isFederatedRequest { // make sure the server is in this room var res fs.QueryJoinedHostServerNamesInRoomResponse @@ -545,7 +555,7 @@ func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) b return queryMembershipRes.IsInRoom } -func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName { +func (rc *reqCtx) getServersForEventID(eventID string) []spec.ServerName { if rc.req.RoomID == "" { util.GetLogger(rc.ctx).WithField("event_id", eventID).Error( "getServersForEventID: event exists in unknown room", @@ -594,7 +604,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation // lookForEvent returns the event for the event ID given, by trying to query remote servers // if the event ID is unknown via /event_relationships. -func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) lookForEvent(eventID string) *types.HeaderedEvent { event := rc.getLocalEvent(rc.req.RoomID, eventID) if event == nil { queryRes := rc.remoteEventRelationships(eventID) @@ -603,7 +613,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent rc.injectResponseToRoomserver(queryRes) for _, ev := range queryRes.ParsedEvents { if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { - return ev.Headered(ev.Version()) + return &types.HeaderedEvent{PDU: ev} } } } @@ -625,7 +635,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent return nil } -func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) getLocalEvent(roomID, eventID string) *types.HeaderedEvent { var queryEventsRes roomserver.QueryEventsByIDResponse err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ RoomID: roomID, @@ -646,7 +656,7 @@ func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.Heade // into the roomserver as KindOutlier, with auth chains. func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsResponse) { var stateEvents gomatrixserverlib.EventJSONs - var messageEvents []*gomatrixserverlib.Event + var messageEvents []gomatrixserverlib.PDU for _, ev := range res.ParsedEvents { if ev.StateKey() != nil { stateEvents = append(stateEvents, ev.JSON()) @@ -654,18 +664,18 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo messageEvents = append(messageEvents, ev) } } - respState := fclient.RespState{ + respState := &fclient.RespState{ AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder := respState.Events(rc.roomVersion) + eventsInOrder := gomatrixserverlib.LineariseStateResponse(rc.roomVersion, respState) // everything gets sent as an outlier because auth chain events may be disjoint from the DAG // as may the threaded events. var ires []roomserver.InputRoomEvent for _, outlier := range append(eventsInOrder, messageEvents...) { ires = append(ires, roomserver.InputRoomEvent{ Kind: roomserver.KindOutlier, - Event: outlier.Headered(outlier.Version()), + Event: &types.HeaderedEvent{PDU: outlier}, }) } // we've got the data by this point so use a background context @@ -684,12 +694,12 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo } } -func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) { +func (rc *reqCtx) addChildMetadata(ev *types.HeaderedEvent) { count, hash := rc.getChildMetadata(ev.EventID()) if count == 0 { return } - err := ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hash)) + err := ev.SetUnsignedField("children_hash", spec.Base64Bytes(hash)) if err != nil { util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash") } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 3c4431489f..16fb3efe14 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -18,11 +18,13 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/mscs/msc2836" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -170,7 +172,7 @@ func TestMSC2836(t *testing.T) { bob: {roomID}, charlie: {roomID}, }, - events: map[string]*gomatrixserverlib.HeaderedEvent{ + events: map[string]*types.HeaderedEvent{ eventA.EventID(): eventA, eventB.EventID(): eventB, eventC.EventID(): eventC, @@ -181,7 +183,7 @@ func TestMSC2836(t *testing.T) { eventH.EventID(): eventH, }, } - router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{ + router := injectEvents(t, nopUserAPI, nopRsAPI, []*types.HeaderedEvent{ eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH, }) cancel := runServer(t, router) @@ -395,7 +397,7 @@ func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelatio func runServer(t *testing.T, router *mux.Router) func() { t.Helper() externalServ := &http.Server{ - Addr: string(":8009"), + Addr: string("127.0.0.1:8009"), WriteTimeout: 60 * time.Second, Handler: router, } @@ -520,7 +522,15 @@ type testRoomserverAPI struct { // We'll override the functions we care about. roomserver.RoomserverInternalAPI userToJoinedRooms map[string][]string - events map[string]*gomatrixserverlib.HeaderedEvent + events map[string]*types.HeaderedEvent +} + +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil } func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { @@ -534,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver } func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { - rooms := r.userToJoinedRooms[req.UserID] + rooms := r.userToJoinedRooms[req.UserID.String()] for _, roomID := range rooms { if roomID == req.RoomID { res.IsInRoom = true @@ -546,7 +556,7 @@ func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roo return nil } -func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { +func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*types.HeaderedEvent) *mux.Router { t.Helper() cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ @@ -578,28 +588,28 @@ type fledglingEvent struct { RoomID string } -func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { +func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEvent) { t.Helper() roomVer := gomatrixserverlib.RoomVersionV6 seed := make([]byte, ed25519.SeedSize) // zero seed key := ed25519.NewKeyFromSeed(seed) - eb := gomatrixserverlib.EventBuilder{ - Sender: ev.Sender, + eb := gomatrixserverlib.MustGetRoomVersion(roomVer).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ + SenderID: ev.Sender, Depth: 999, Type: ev.Type, StateKey: ev.StateKey, RoomID: ev.RoomID, - } + }) err := eb.SetContent(ev.Content) if err != nil { t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) } // make sure the origin_server_ts changes so we can test recency time.Sleep(1 * time.Millisecond) - signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) + signedEvent, err := eb.Build(time.Now(), spec.ServerName("localhost"), "ed25519:test", key) if err != nil { t.Fatalf("mustCreateEvent: failed to sign event: %s", err) } - h := signedEvent.Headered(roomVer) + h := &types.HeaderedEvent{PDU: signedEvent} return h } diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 1cf7e87856..6a45f08a4e 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -8,21 +8,22 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) type eventInfo struct { EventID string - OriginServerTS gomatrixserverlib.Timestamp + OriginServerTS spec.Timestamp RoomID string } type Database interface { // StoreRelation stores the parent->child and child->parent relationship for later querying. // Also stores the event metadata e.g timestamp - StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error + StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the // provided `relType`. The returned slice is sorted by origin_server_ts according to whether // `recentFirst` is true or false. @@ -34,7 +35,7 @@ type Database interface { // UpdateChildMetadata persists the children_count and children_hash from this event if and only if // the count is greater than what was previously there. If the count is updated, the event will be // updated to be unexplored. - UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error + UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error // ChildMetadata returns the children_count and children_hash for the event ID in question. // Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set // back to `false` when a larger count is inserted via UpdateChildMetadata. @@ -221,7 +222,7 @@ func newSQLiteDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOption return &d, nil } -func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { +func (p *DB) StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error { parent, child, relType := parentChildEventIDs(ev) if parent == "" || child == "" { return nil @@ -243,7 +244,7 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv }) } -func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { +func (p *DB) UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error { eventCount, eventHash := extractChildMetadata(ev) if eventCount == 0 { return nil // nothing to update with @@ -314,7 +315,7 @@ func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*even return &ei, nil } -func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { +func parentChildEventIDs(ev *types.HeaderedEvent) (parent, child, relType string) { if ev == nil { return } @@ -333,7 +334,7 @@ func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, re return body.Relationship.EventID, ev.EventID(), body.Relationship.RelType } -func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, servers []string) { +func roomIDAndServers(ev *types.HeaderedEvent) (roomID string, servers []string) { servers = []string{} if ev == nil { return @@ -348,10 +349,10 @@ func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, serve return body.RoomID, body.Servers } -func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) { +func extractChildMetadata(ev *types.HeaderedEvent) (count int, hash []byte) { unsigned := struct { - Counts map[string]int `json:"children"` - Hash gomatrixserverlib.Base64Bytes `json:"children_hash"` + Counts map[string]int `json:"children"` + Hash spec.Base64Bytes `json:"children_hash"` }{} if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil { // expected if there is no unsigned field at all diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 965af92071..3e5ffda925 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -28,15 +28,16 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -64,7 +65,7 @@ func Enable( fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { - fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + fedReq, errResp := fclient.VerifyHTTPRequest( req, time.Now(), cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { @@ -85,16 +86,16 @@ func Enable( } func federatedSpacesHandler( - ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, roomID string, + ctx context.Context, fedReq *fclient.FederationRequest, roomID string, cache caching.SpaceSummaryRoomsCache, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, - thisServer gomatrixserverlib.ServerName, + thisServer spec.ServerName, ) util.JSONResponse { u, err := url.Parse(fedReq.RequestURI()) if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidParam("bad request uri"), + JSON: spec.InvalidParam("bad request uri"), } } @@ -122,7 +123,7 @@ func spacesHandler( rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, cache caching.SpaceSummaryRoomsCache, - thisServer gomatrixserverlib.ServerName, + thisServer spec.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { // declared outside the returned handler so it persists between calls // TODO: clear based on... time? @@ -162,8 +163,8 @@ type paginationInfo struct { type walker struct { rootRoomID string caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName + serverName spec.ServerName + thisServer spec.ServerName rsAPI roomserver.RoomserverInternalAPI fsAPI fs.FederationInternalAPI ctx context.Context @@ -212,13 +213,13 @@ func (w *walker) walk() util.JSONResponse { // CS API format return util.JSONResponse{ Code: 403, - JSON: jsonerror.Forbidden("room is unknown/forbidden"), + JSON: spec.Forbidden("room is unknown/forbidden"), } } else { // SS API format return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("room is unknown/forbidden"), + JSON: spec.NotFound("room is unknown/forbidden"), } } } @@ -231,7 +232,7 @@ func (w *walker) walk() util.JSONResponse { if cache == nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("invalid from"), + JSON: spec.InvalidParam("invalid from"), } } } else { @@ -269,7 +270,7 @@ func (w *walker) walk() util.JSONResponse { // if this room is not a space room, skip. var roomType string - create := w.stateEvent(rv.roomID, gomatrixserverlib.MRoomCreate, "") + create := w.stateEvent(rv.roomID, spec.MRoomCreate, "") if create != nil { // escape the `.`s so gjson doesn't think it's nested roomType = gjson.GetBytes(create.Content(), strings.ReplaceAll(ConstCreateEventContentKey, ".", `\.`)).Str @@ -375,7 +376,7 @@ func (w *walker) walk() util.JSONResponse { if len(discoveredRooms) == 0 { return util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound("room is unknown/forbidden"), + JSON: spec.NotFound("room is unknown/forbidden"), } } return util.JSONResponse{ @@ -387,7 +388,7 @@ func (w *walker) walk() util.JSONResponse { } } -func (w *walker) stateEvent(roomID, evType, stateKey string) *gomatrixserverlib.HeaderedEvent { +func (w *walker) stateEvent(roomID, evType, stateKey string) *types.HeaderedEvent { var queryRes roomserver.QueryCurrentStateResponse tuple := gomatrixserverlib.StateKeyTuple{ EventType: evType, @@ -434,7 +435,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *fclient.MSC294 if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) + res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, spec.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue @@ -484,11 +485,11 @@ func (w *walker) authorised(roomID, parentRoomID string) (authed, isJoinedOrInvi func (w *walker) authorisedServer(roomID string) bool { // Check history visibility / join rules first hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", } joinRuleTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomJoinRules, + EventType: spec.MRoomJoinRules, StateKey: "", } var queryRoomRes roomserver.QueryCurrentStateResponse @@ -522,11 +523,11 @@ func (w *walker) authorisedServer(roomID string) bool { return false } - if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock { + if rule == spec.Public || rule == spec.Knock { return true } - if rule == gomatrixserverlib.Restricted { + if rule == spec.Restricted { allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership")...) } } @@ -556,15 +557,15 @@ func (w *walker) authorisedServer(roomID string) bool { // Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true. func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoinedOrInvited bool) { hisVisTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomHistoryVisibility, + EventType: spec.MRoomHistoryVisibility, StateKey: "", } joinRuleTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomJoinRules, + EventType: spec.MRoomJoinRules, StateKey: "", } roomMemberTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: w.caller.UserID, } var queryRes roomserver.QueryCurrentStateResponse @@ -581,7 +582,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi memberEv := queryRes.StateEvents[roomMemberTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join || membership == gomatrixserverlib.Invite { + if membership == spec.Join || membership == spec.Invite { return true, true } } @@ -598,9 +599,9 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi rule, ruleErr := joinRuleEv.JoinRule() if ruleErr != nil { util.GetLogger(w.ctx).WithError(ruleErr).WithField("parent_room_id", parentRoomID).Warn("failed to get join rule") - } else if rule == gomatrixserverlib.Public || rule == gomatrixserverlib.Knock { + } else if rule == spec.Public || rule == spec.Knock { allowed = true - } else if rule == gomatrixserverlib.Restricted { + } else if rule == spec.Restricted { allowedRoomIDs := w.restrictedJoinRuleAllowedRooms(joinRuleEv, "m.room_membership") // check parent is in the allowed set for _, a := range allowedRoomIDs { @@ -625,7 +626,7 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi memberEv = queryRes2.StateEvents[roomMemberTuple] if memberEv != nil { membership, _ := memberEv.Membership() - if membership == gomatrixserverlib.Join { + if membership == spec.Join { return true, false } } @@ -635,9 +636,9 @@ func (w *walker) authorisedUser(roomID, parentRoomID string) (authed bool, isJoi return false, false } -func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.HeaderedEvent, allowType string) (allows []string) { +func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *types.HeaderedEvent, allowType string) (allows []string) { rule, _ := joinRuleEv.JoinRule() - if rule != gomatrixserverlib.Restricted { + if rule != spec.Restricted { return nil } var jrContent gomatrixserverlib.JoinRuleContent @@ -656,7 +657,7 @@ func (w *walker) restrictedJoinRuleAllowedRooms(joinRuleEv *gomatrixserverlib.He // references returns all child references pointing to or from this room. func (w *walker) childReferences(roomID string) ([]fclient.MSC2946StrippedEvent, error) { createTuple := gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomCreate, + EventType: spec.MRoomCreate, StateKey: "", } var res roomserver.QueryCurrentStateResponse @@ -691,7 +692,7 @@ func (w *walker) childReferences(roomID string) ([]fclient.MSC2946StrippedEvent, // else we'll incorrectly walk redacted events (as the link // is in the state_key) if content.Get("via").Exists() { - strip := stripped(ev.Event) + strip := stripped(ev.PDU) if strip == nil { continue } @@ -721,7 +722,7 @@ func (s set) isSet(val string) bool { return ok } -func stripped(ev *gomatrixserverlib.Event) *fclient.MSC2946StrippedEvent { +func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent { if ev.StateKey() == nil { return nil } @@ -729,7 +730,7 @@ func stripped(ev *gomatrixserverlib.Event) *fclient.MSC2946StrippedEvent { Type: ev.Type(), StateKey: *ev.StateKey(), Content: ev.Content(), - Sender: ev.Sender(), + Sender: string(ev.SenderID()), OriginServerTS: ev.OriginServerTS(), } } diff --git a/setup/process/process.go b/setup/process/process.go index b2d2844a8a..9a3d6401ca 100644 --- a/setup/process/process.go +++ b/setup/process/process.go @@ -10,7 +10,7 @@ import ( type ProcessContext struct { mu sync.RWMutex - wg *sync.WaitGroup // used to wait for components to shutdown + wg sync.WaitGroup // used to wait for components to shutdown ctx context.Context // cancelled when Stop is called shutdown context.CancelFunc // shut down Dendrite degraded map[string]struct{} // reasons why the process is degraded @@ -21,7 +21,7 @@ func NewProcessContext() *ProcessContext { return &ProcessContext{ ctx: ctx, shutdown: shutdown, - wg: &sync.WaitGroup{}, + wg: sync.WaitGroup{}, } } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 43dc0f5176..3ed455e9f4 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -21,7 +21,7 @@ import ( "time" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" @@ -49,7 +49,7 @@ type OutputClientDataConsumer struct { db storage.Database stream streams.StreamProvider notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName + serverName spec.ServerName fts fulltext.Indexer cfg *config.SyncAPI } @@ -121,9 +121,9 @@ func (s *OutputClientDataConsumer) Start() error { switch ev.Type() { case "m.room.message": e.Content = gjson.GetBytes(ev.Content(), "body").String() - case gomatrixserverlib.MRoomName: + case spec.MRoomName: e.Content = gjson.GetBytes(ev.Content(), "name").String() - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: e.Content = gjson.GetBytes(ev.Content(), "topic").String() default: continue diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 6e3150c296..c7c0866ba9 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) @@ -108,7 +108,7 @@ func (s *PresenceConsumer) Start() error { for i := range deviceRes.Devices { if int64(presence.LastActiveTS) < deviceRes.Devices[i].LastSeenTS { - presence.LastActiveTS = gomatrixserverlib.Timestamp(deviceRes.Devices[i].LastSeenTS) + presence.LastActiveTS = spec.Timestamp(deviceRes.Devices[i].LastSeenTS) } } @@ -161,11 +161,11 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool // already checked, so no need to check error p, _ := types.PresenceFromString(presence) - s.EmitPresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync) + s.EmitPresence(ctx, userID, p, statusMsg, spec.Timestamp(ts), fromSync) return true } -func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts gomatrixserverlib.Timestamp, fromSync bool) { +func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts spec.Timestamp, fromSync bool) { pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, ts, fromSync) if err != nil { logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user") diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index e39d43f949..69571f90cf 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -19,7 +19,6 @@ import ( "strconv" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -30,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. @@ -89,7 +89,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return true } - output.Timestamp = gomatrixserverlib.Timestamp(timestamp) + output.Timestamp = spec.Timestamp(timestamp) streamPos, err := s.db.StoreReceipt( s.ctx, diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 21f6104d61..e6b5ddbb0c 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -21,7 +21,7 @@ import ( "fmt" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -108,7 +109,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms // Ignore redaction events. We will add them to the database when they are // validated (when we receive OutputTypeRedactedEvent) event := output.NewRoomEvent.Event - if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { + if event.Type() == spec.MRoomRedaction && event.StateKey() == nil { // in the special case where the event redacts itself, just pass the message through because // we will never see the other part of the pair if event.Redacts() != event.EventID() { @@ -150,7 +151,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms func (s *OutputRoomEventConsumer) onRedactEvent( ctx context.Context, msg api.OutputRedactedEvent, ) error { - err := s.db.RedactEvent(ctx, msg.RedactedEventID, msg.RedactedBecause) + err := s.db.RedactEvent(ctx, msg.RedactedEventID, msg.RedactedBecause, s.rsAPI) if err != nil { log.WithError(err).Error("RedactEvent error'd") return err @@ -255,16 +256,19 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } } - pduPos, err := s.db.WriteEvent( - ctx, - ev, - addsStateEvents, - msg.AddsStateEventIDs, - msg.RemovesStateEventIDs, - msg.TransactionID, - false, - msg.HistoryVisibility, - ) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return err + } + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID()) + if err != nil { + return err + } + + ev.UserID = *userID + + pduPos, err := s.db.WriteEvent(ctx, ev, addsStateEvents, msg.AddsStateEventIDs, msg.RemovesStateEventIDs, msg.TransactionID, false, msg.HistoryVisibility) if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ @@ -314,16 +318,19 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( // hack but until we have some better strategy for dealing with // old events in the sync API, this should at least prevent us // from confusing clients into thinking they've joined/left rooms. - pduPos, err := s.db.WriteEvent( - ctx, - ev, - []*gomatrixserverlib.HeaderedEvent{}, - []string{}, // adds no state - []string{}, // removes no state - nil, // no transaction - ev.StateKey() != nil, // exclude from sync?, - msg.HistoryVisibility, - ) + + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return err + } + + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, ev.SenderID()) + if err != nil { + return err + } + ev.UserID = *userID + + pduPos, err := s.db.WriteEvent(ctx, ev, []*rstypes.HeaderedEvent{}, []string{}, []string{}, nil, ev.StateKey() != nil, msg.HistoryVisibility) if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ @@ -361,8 +368,8 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( return nil } -func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) { - if ev.Type() != gomatrixserverlib.MRoomMember { +func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rstypes.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) { + if ev.Type() != spec.MRoomMember { return sp, nil } membership, err := ev.Membership() @@ -370,9 +377,21 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom return sp, fmt.Errorf("ev.Membership: %w", err) } // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) - if membership == gomatrixserverlib.Join { + if membership == spec.Join { // check it's a local join - if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { + if ev.StateKey() == nil { + return sp, fmt.Errorf("unexpected nil state_key") + } + + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return sp, err + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) + if err != nil || userID == nil { + return sp, fmt.Errorf("failed getting userID for sender: %w", err) + } + if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) { return sp, nil } @@ -394,9 +413,21 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( if msg.Event.StateKey() == nil { return } - if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil { + + validRoomID, err := spec.NewRoomID(msg.Event.RoomID()) + if err != nil { + return + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey())) + if err != nil || userID == nil { return } + if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) { + return + } + + msg.Event.UserID = *userID + pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { sentry.CaptureException(err) @@ -433,13 +464,31 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Only notify clients about retired invite events, if the user didn't accept the invite. // The PDU stream will also receive an event about accepting the invitation, so there should // be a "smooth" transition from invite -> join, and not invite -> leave -> join - if msg.Membership == gomatrixserverlib.Join { + if msg.Membership == spec.Join { return } // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) - s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) + validRoomID, err := spec.NewRoomID(msg.RoomID) + if err != nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "room_id": msg.RoomID, + log.ErrorKey: err, + }).Errorf("roomID is invalid") + return + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID) + if err != nil || userID == nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "sender_id": msg.TargetSenderID, + log.ErrorKey: err, + }).Errorf("failed to find userID for sender") + return + } + s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String()) } func (s *OutputRoomEventConsumer) onNewPeek( @@ -495,7 +544,8 @@ func (s *OutputRoomEventConsumer) onPurgeRoom( } } -func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { +func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) (*rstypes.HeaderedEvent, error) { + event.StateKeyResolved = event.StateKey() if event.StateKey() == nil { return event, nil } @@ -515,6 +565,29 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head return event, err } + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return event, err + } + + if event.StateKey() != nil { + if *event.StateKey() != "" { + var sku *spec.UserID + sku, err = s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, spec.SenderID(stateKey)) + if err == nil && sku != nil { + sKey := sku.String() + event.StateKeyResolved = &sKey + } + } + } + + userID, err := s.rsAPI.QueryUserIDForSender(s.ctx, *validRoomID, event.SenderID()) + if err != nil { + return event, err + } + + event.UserID = *userID + if prevEvent == nil || prevEvent.EventID() == event.EventID() { return event, nil } @@ -522,15 +595,15 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head prev := types.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.Sender(), + PrevSenderID: string(prevEvent.SenderID()), } - event.Event, err = event.SetUnsigned(prev) + event.PDU, err = event.SetUnsigned(prev) succeeded = true return event, err } -func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition) error { +func (s *OutputRoomEventConsumer) writeFTS(ev *rstypes.HeaderedEvent, pduPosition types.StreamPosition) error { if !s.cfg.Fulltext.Enabled { return nil } @@ -544,11 +617,11 @@ func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, switch ev.Type() { case "m.room.message": e.Content = gjson.GetBytes(ev.Content(), "body").String() - case gomatrixserverlib.MRoomName: + case spec.MRoomName: e.Content = gjson.GetBytes(ev.Content(), "name").String() - case gomatrixserverlib.MRoomTopic: + case spec.MRoomTopic: e.Content = gjson.GetBytes(ev.Content(), "topic").String() - case gomatrixserverlib.MRoomRedaction: + case spec.MRoomRedaction: log.Tracef("Redacting event: %s", ev.Redacts()) if err := s.fts.Delete(ev.Redacts()); err != nil { return fmt.Errorf("failed to delete entry from fulltext index: %w", err) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 32208c5850..7f387dc09d 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -20,6 +20,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct { topic string db storage.Database userAPI api.SyncKeyAPI - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool stream streams.StreamProvider notifier *notifier.Notifier } diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index ee695f0f54..ce6846ca48 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -20,11 +20,13 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" ) @@ -50,7 +52,7 @@ var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec( ) var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{ - gomatrixserverlib.WorldReadable: 0, + spec.WorldReadable: 0, gomatrixserverlib.HistoryVisibilityShared: 1, gomatrixserverlib.HistoryVisibilityInvited: 2, gomatrixserverlib.HistoryVisibilityJoined: 3, @@ -72,23 +74,23 @@ func (ev eventVisibility) allowed() (allowed bool) { return true case gomatrixserverlib.HistoryVisibilityJoined: // If the user’s membership was join, allow. - if ev.membershipAtEvent == gomatrixserverlib.Join { + if ev.membershipAtEvent == spec.Join { return true } return false case gomatrixserverlib.HistoryVisibilityShared: // If the user’s membership was join, allow. // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow. - if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join { + if ev.membershipAtEvent == spec.Join || ev.membershipCurrent == spec.Join { return true } return false case gomatrixserverlib.HistoryVisibilityInvited: // If the user’s membership was join, allow. - if ev.membershipAtEvent == gomatrixserverlib.Join { + if ev.membershipAtEvent == spec.Join { return true } - if ev.membershipAtEvent == gomatrixserverlib.Invite { + if ev.membershipAtEvent == spec.Invite { return true } return false @@ -97,16 +99,16 @@ func (ev eventVisibility) allowed() (allowed bool) { } } -// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents. +// ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents. // Returns the filtered events and an error, if any. func ApplyHistoryVisibilityFilter( ctx context.Context, syncDB storage.DatabaseTransaction, rsAPI api.SyncRoomserverAPI, - events []*gomatrixserverlib.HeaderedEvent, + events []*types.HeaderedEvent, alwaysIncludeEventIDs map[string]struct{}, userID, endpoint string, -) ([]*gomatrixserverlib.HeaderedEvent, error) { +) ([]*types.HeaderedEvent, error) { if len(events) == 0 { return events, nil } @@ -119,7 +121,7 @@ func ApplyHistoryVisibilityFilter( } // Get the mapping from eventID -> eventVisibility - eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) + eventsFiltered := make([]*types.HeaderedEvent, 0, len(events)) visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) for _, ev := range events { evVis := visibilities[ev.EventID()] @@ -132,9 +134,21 @@ func ApplyHistoryVisibilityFilter( } } // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) { - eventsFiltered = append(eventsFiltered, ev) - continue + + user, err := spec.NewUserID(userID, true) + if err != nil { + return nil, err + } + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user) + if err == nil { + if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { + eventsFiltered = append(eventsFiltered, ev) + continue + } } // Always allow history evVis events on boundaries. This is done // by setting the effective evVis to the least restrictive @@ -169,7 +183,7 @@ func ApplyHistoryVisibilityFilter( func visibilityForEvents( ctx context.Context, rsAPI api.SyncRoomserverAPI, - events []*gomatrixserverlib.HeaderedEvent, + events []*types.HeaderedEvent, userID, roomID string, ) map[string]eventVisibility { eventIDs := make([]string, len(events)) @@ -195,7 +209,7 @@ func visibilityForEvents( for _, event := range events { eventID := event.EventID() vis := eventVisibility{ - membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident + membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident visibility: event.Visibility, } ev, ok := membershipResp.Membership[eventID] diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 17d63708a6..24ffcc0414 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -20,6 +20,7 @@ import ( keytypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -159,7 +160,7 @@ func TrackChangedUsers( RoomIDs: newlyLeftRooms, StateTuples: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: "*", }, }, @@ -168,12 +169,20 @@ func TrackChangedUsers( if err != nil { return nil, nil, err } - for _, state := range stateRes.Rooms { + for roomID, state := range stateRes.Rooms { + validRoomID, roomErr := spec.NewRoomID(roomID) + if roomErr != nil { + continue + } for tuple, membership := range state { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } - queryRes.UserIDsToCount[tuple.StateKey]-- + user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey)) + if queryErr != nil || user == nil { + continue + } + queryRes.UserIDsToCount[user.String()]-- } } @@ -201,7 +210,7 @@ func TrackChangedUsers( RoomIDs: newlyJoinedRooms, StateTuples: []gomatrixserverlib.StateKeyTuple{ { - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: "*", }, }, @@ -210,14 +219,22 @@ func TrackChangedUsers( if err != nil { return nil, left, err } - for _, state := range stateRes.Rooms { + for roomID, state := range stateRes.Rooms { + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + continue + } for tuple, membership := range state { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } // new user who we weren't previously sharing rooms with if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { - changed = append(changed, tuple.StateKey) // changed is returned + user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey)) + if err != nil || user == nil { + continue + } + changed = append(changed, user.String()) // changed is returned } } } @@ -283,7 +300,7 @@ func membershipEventPresent(events []synctypes.ClientEvent, userID string) bool for _, ev := range events { // it's enough to know that we have our member event here, don't need to check membership content // as it's implied by being in the respective section of the sync response. - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + if ev.Type == spec.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { // ignore e.g. join -> join changes if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str { continue @@ -302,7 +319,7 @@ func membershipEventPresent(events []synctypes.ClientEvent, userID string) bool func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) { for _, room := range res.Rooms.Join { for _, ev := range room.Timeline.Events { - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { + if ev.Type == spec.MRoomMember && ev.StateKey != nil { if strings.Contains(string(ev.Content), `"join"`) { joinUserIDs = append(joinUserIDs, *ev.StateKey) } else if strings.Contains(string(ev.Content), `"invite"`) { diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index f775276fee..3f5e990c47 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/roomserver/api" @@ -33,20 +34,16 @@ func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.Perform func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error { - return nil +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) { } func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error { - return nil +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) { } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error { - return nil +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) { } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error { - return nil +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) { } func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil @@ -59,8 +56,7 @@ func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.Query return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error { - return nil +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) { } type mockRoomserverAPI struct { @@ -68,6 +64,10 @@ type mockRoomserverAPI struct { roomIDToJoinedMembers map[string][]string } +func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // QueryRoomsForUser retrieves a list of room IDs matching the given query. func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { return nil @@ -76,12 +76,12 @@ func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.Quer // QueryBulkStateContent does a bulk query for state event content in the given rooms. func (s *mockRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - if req.AllowWildcards && len(req.StateTuples) == 1 && req.StateTuples[0].EventType == gomatrixserverlib.MRoomMember && req.StateTuples[0].StateKey == "*" { + if req.AllowWildcards && len(req.StateTuples) == 1 && req.StateTuples[0].EventType == spec.MRoomMember && req.StateTuples[0].StateKey == "*" { for _, roomID := range req.RoomIDs { res.Rooms[roomID] = make(map[gomatrixserverlib.StateKeyTuple]string) for _, userID := range s.roomIDToJoinedMembers[roomID] { res.Rooms[roomID][gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomMember, + EventType: spec.MRoomMember, StateKey: userID, }] = "join" } diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 32c6f04e07..1c6318e638 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -20,9 +20,11 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -35,7 +37,8 @@ import ( // the event, but the token has already advanced by the time they fetch it, resulting // in missed events. type Notifier struct { - lock *sync.RWMutex + lock *sync.RWMutex + rsAPI api.SyncRoomserverAPI // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToJoinedUsers map[string]*userIDSet // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine @@ -54,8 +57,9 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier() *Notifier { +func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier { return &Notifier{ + rsAPI: rsAPI, roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), @@ -78,7 +82,7 @@ func (n *Notifier) SetCurrentPosition(currPos types.StreamingToken) { // OnNewEvent is called when a new event is received from the room server. Must only be // called from a single goroutine, to avoid races between updates which could set the // current sync position incorrectly. -// Chooses which user sync streams to update by a provided *gomatrixserverlib.Event +// Chooses which user sync streams to update by a provided gomatrixserverlib.PDU // (based on the users in the event's room), // a roomID directly, or a list of user IDs, prioritised by parameter ordering. // posUpdate contains the latest position(s) for one or more types of events. @@ -86,7 +90,7 @@ func (n *Notifier) SetCurrentPosition(currPos types.StreamingToken) { // Typically a consumer supplies a posUpdate with the latest sync position for the // event type it handles, leaving other fields as 0. func (n *Notifier) OnNewEvent( - ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string, + ev *rstypes.HeaderedEvent, roomID string, userIDs []string, posUpdate types.StreamingToken, ) { // update the current position then notify relevant /sync streams. @@ -97,32 +101,45 @@ func (n *Notifier) OnNewEvent( n._removeEmptyUserStreams() if ev != nil { + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: RoomID is invalid", + ) + return + } // Map this event's room_id to a list of joined users, and wake them up. usersToNotify := n._joinedUsers(ev.RoomID()) // Map this event's room_id to a list of peeking devices, and wake them up. peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID := *ev.StateKey() - membership, err := ev.Membership() + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey())) if err != nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( - "Notifier.OnNewEvent: Failed to unmarshal member event", + "Notifier.OnNewEvent: Failed to find the userID for this event", ) } else { - // Keep the joined user map up-to-date - switch membership { - case gomatrixserverlib.Invite: - usersToNotify = append(usersToNotify, targetUserID) - case gomatrixserverlib.Join: - // Manually append the new user's ID so they get notified - // along all members in the room - usersToNotify = append(usersToNotify, targetUserID) - n._addJoinedUser(ev.RoomID(), targetUserID) - case gomatrixserverlib.Leave: - fallthrough - case gomatrixserverlib.Ban: - n._removeJoinedUser(ev.RoomID(), targetUserID) + membership, err := ev.Membership() + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: Failed to unmarshal member event", + ) + } else { + // Keep the joined user map up-to-date + switch membership { + case spec.Invite: + usersToNotify = append(usersToNotify, targetUserID.String()) + case spec.Join: + // Manually append the new user's ID so they get notified + // along all members in the room + usersToNotify = append(usersToNotify, targetUserID.String()) + n._addJoinedUser(ev.RoomID(), targetUserID.String()) + case spec.Leave: + fallthrough + case spec.Ban: + n._removeJoinedUser(ev.RoomID(), targetUserID.String()) + } } } } diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index b06313712d..f86301a060 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -22,16 +22,18 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) var ( - randomMessageEvent gomatrixserverlib.HeaderedEvent - aliceInviteBobEvent gomatrixserverlib.HeaderedEvent - bobLeaveEvent gomatrixserverlib.HeaderedEvent + randomMessageEvent rstypes.HeaderedEvent + aliceInviteBobEvent rstypes.HeaderedEvent + bobLeaveEvent rstypes.HeaderedEvent syncPositionVeryOld = types.StreamingToken{PDUPosition: 5} syncPositionBefore = types.StreamingToken{PDUPosition: 11} syncPositionAfter = types.StreamingToken{PDUPosition: 12} @@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { } } +type TestRoomServer struct{ api.SyncRoomserverAPI } + +func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { @@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) { // Test that new events to a joined room unblocks the request. func TestNewEventAndJoinedToRoom(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { } func TestCorrectStream(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) stream := lockedFetchUserStream(n, bob, bobDev) if stream.UserID != bob { @@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) { } func TestCorrectStreamWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) awoken := make(chan string) @@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) { // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) { // Test that all blocked requests get woken up on a new event. func TestMultipleRequestWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // listen as bob. Make bob leave room. Make alice send event to room. // Make sure alice gets woken up only and not bob as well. - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, diff --git a/syncapi/producers/federationapi_presence.go b/syncapi/producers/federationapi_presence.go index dc03457e3f..eab1b0b25e 100644 --- a/syncapi/producers/federationapi_presence.go +++ b/syncapi/producers/federationapi_presence.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" ) @@ -37,7 +37,7 @@ func (f *FederationAPIPresenceProducer) SendPresence( msg.Header.Set(jetstream.UserID, userID) msg.Header.Set("presence", presence.String()) msg.Header.Set("from_sync", "true") // only update last_active_ts and presence - msg.Header.Set("last_active_ts", strconv.Itoa(int(gomatrixserverlib.AsTimestamp(time.Now())))) + msg.Header.Set("last_active_ts", strconv.Itoa(int(spec.AsTimestamp(time.Now())))) if statusMsg != nil { msg.Header.Set("status_msg", *statusMsg) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index dd42c7ac41..649d77b41d 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -23,16 +23,17 @@ import ( "strconv" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" roomserver "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" ) @@ -55,7 +56,10 @@ func Context( ) util.JSONResponse { snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -73,7 +77,7 @@ func Context( } return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidParam(errMsg), + JSON: spec.InvalidParam(errMsg), Headers: nil, } } @@ -81,17 +85,27 @@ func Context( *filter.Rooms = append(*filter.Rooms, roomID) } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} - membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} + membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID} if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { logrus.WithError(err).Error("unable to query membership") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipRes.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } @@ -112,19 +126,25 @@ func Context( if err == sql.ErrNoRows { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound(fmt.Sprintf("Event %s not found", eventID)), + JSON: spec.NotFound(fmt.Sprintf("Event %s not found", eventID)), } } logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // verify the user is allowed to see the context for this room/event startTime := time.Now() - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.WithFields(logrus.Fields{ "duration": time.Since(startTime), @@ -133,27 +153,36 @@ func Context( if len(filteredEvents) == 0 { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("User is not allowed to query context"), + JSON: spec.Forbidden("User is not allowed to query context"), } } eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch after events") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } startTime = time.Now() eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.WithFields(logrus.Fields{ @@ -165,30 +194,46 @@ func Context( state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to fetch current room state") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - eventsBeforeClient := synctypes.HeaderedToClientEvents(eventsBeforeFiltered, synctypes.FormatAll) - eventsAfterClient := synctypes.HeaderedToClientEvents(eventsAfterFiltered, synctypes.FormatAll) + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.HeaderedToClientEvents(allEvents, synctypes.FormatAll) + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } - ev := synctypes.HeaderedToClientEvent(&requestedEvent, synctypes.FormatAll) + ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, requestedEvent) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.HeaderedToClientEvents(newState, synctypes.FormatAll), + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), } if len(response.State) > filter.Limit { @@ -211,9 +256,9 @@ func Context( // and an error, if any. func applyHistoryVisibilityOnContextEvents( ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI, - eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, + eventsBefore, eventsAfter []*rstypes.HeaderedEvent, userID string, -) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { +) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) { eventIDsBefore := make(map[string]struct{}, len(eventsBefore)) eventIDsAfter := make(map[string]struct{}, len(eventsAfter)) @@ -244,7 +289,7 @@ func applyHistoryVisibilityOnContextEvents( return filteredBefore, filteredAfter, nil } -func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { +func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { if len(startEvents) > 0 { start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID()) if err != nil { @@ -264,7 +309,7 @@ func applyLazyLoadMembers( roomID string, events []synctypes.ClientEvent, lazyLoadCache caching.LazyLoadCache, -) ([]*gomatrixserverlib.HeaderedEvent, error) { +) ([]*rstypes.HeaderedEvent, error) { eventSenders := make(map[string]struct{}) // get members who actually send an event for _, e := range events { @@ -283,7 +328,7 @@ func applyLazyLoadMembers( // Query missing membership events filter := synctypes.DefaultStateFilter() filter.Senders = &wantUsers - filter.Types = &[]string{gomatrixserverlib.MRoomMember} + filter.Types = &[]string{spec.MRoomMember} memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) if err != nil { return nil, err diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go index 266ad4adcc..c4eecbdb8a 100644 --- a/syncapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -23,11 +23,11 @@ import ( "github.com/matrix-org/util" "github.com/tidwall/gjson" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} @@ -37,13 +37,16 @@ func GetFilter( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot get filters for other users"), + JSON: spec.Forbidden("Cannot get filters for other users"), } } localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } filter := synctypes.DefaultFilter() @@ -53,7 +56,7 @@ func GetFilter( // even though it is not correct. return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.NotFound("No such filter"), + JSON: spec.NotFound("No such filter"), } } @@ -76,14 +79,17 @@ func PutFilter( if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Cannot create filters for other users"), + JSON: spec.Forbidden("Cannot create filters for other users"), } } localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var filter synctypes.Filter @@ -93,14 +99,14 @@ func PutFilter( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read. " + err.Error()), + JSON: spec.BadJSON("The request body could not be read. " + err.Error()), } } if err = json.Unmarshal(body, &filter); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + JSON: spec.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } // the filter `limit` is `int` which defaults to 0 if not set which is not what we want. We want to use the default @@ -115,14 +121,17 @@ func PutFilter( if err = filter.Validate(); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()), + JSON: spec.BadJSON("Invalid filter: " + err.Error()), } } filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } return util.JSONResponse{ diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 84986d3b3f..09c2aef02d 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -20,13 +20,13 @@ import ( "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) // GetEvent implements @@ -51,13 +51,19 @@ func GetEvent( }) if err != nil { logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } events, err := db.Events(ctx, []string{eventID}) if err != nil { logger.WithError(err).Error("GetEvent: syncDB.Events failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } // The requested event does not exist in our database @@ -65,7 +71,7 @@ func GetEvent( logger.Debugf("GetEvent: requested event doesn't exist locally") return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + JSON: spec.NotFound("The event was not found or you do not have permission to read this event"), } } @@ -81,7 +87,7 @@ func GetEvent( logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed") return util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.InternalServerError(), + JSON: spec.InternalServerError{}, } } @@ -91,12 +97,40 @@ func GetEvent( logger.WithField("event_count", len(events)).Debug("GetEvent: can't return the requested event") return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + JSON: spec.NotFound("The event was not found or you do not have permission to read this event"), } } + sender := spec.UserID{} + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("roomID is invalid"), + } + } + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID()) + if err == nil && senderUserID != nil { + sender = *senderUserID + } + + sk := events[0].StateKey() + if sk != nil && *sk != "" { + evRoomID, err := spec.NewRoomID(events[0].RoomID()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("roomID is invalid"), + } + } + skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.HeaderedToClientEvent(events[0], synctypes.FormatAll), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 9ea660f596..5e5d0125fa 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -19,12 +19,13 @@ import ( "math" "net/http" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -58,34 +59,47 @@ func GetMemberships( syncDB storage.Database, rsAPI api.SyncRoomserverAPI, joinedOnly bool, membership, notMembership *string, at string, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } queryReq := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: device.UserID, + UserID: *userID, } var queryRes api.QueryMembershipForUserResponse - if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed") - return jsonerror.InternalServerError() + if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil { + util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !queryRes.HasBeenInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), } } if joinedOnly && !queryRes.IsInRoom { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), } } db, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } defer db.Rollback() // nolint: errcheck @@ -97,7 +111,10 @@ func GetMemberships( atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } @@ -105,13 +122,19 @@ func GetMemberships( eventIDs, err := db.SelectMemberships(req.Context(), roomID, atToken, membership, notMembership) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("db.SelectMemberships failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } qryRes := &api.QueryEventsByIDResponse{} if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } result := qryRes.Events @@ -123,9 +146,35 @@ func GetMemberships( var content databaseJoinedMember if err := json.Unmarshal(ev.Content(), &content); err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) + if err != nil || userID == nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } } - res.Joined[ev.Sender()] = joinedMember(content) + res.Joined[userID.String()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -134,6 +183,8 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.HeaderedToClientEvents(result, synctypes.FormatAll)}, + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + })}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 036178b700..937e20ad8d 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -23,13 +23,14 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" @@ -80,7 +81,10 @@ func OnIncomingMessagesRequest( // request that requires backfilling from the roomserver or federation. snapshot, err := db.NewDatabaseTransaction(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -88,19 +92,22 @@ func OnIncomingMessagesRequest( // check if the user has already forgotten about this room membershipResp, err := getMembershipForUser(req.Context(), roomID, device.UserID, rsAPI) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if !membershipResp.RoomExists { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("room does not exist"), + JSON: spec.Forbidden("room does not exist"), } } if membershipResp.IsRoomForgotten { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("user already forgot about this room"), + JSON: spec.Forbidden("user already forgot about this room"), } } @@ -108,7 +115,7 @@ func OnIncomingMessagesRequest( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unable to parse filter"), + JSON: spec.InvalidParam("unable to parse filter"), } } @@ -130,7 +137,7 @@ func OnIncomingMessagesRequest( if dir != "b" && dir != "f" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + JSON: spec.MissingParam("Bad or missing dir query parameter (should be either 'b' or 'f')"), } } // A boolean is easier to handle in this case, especially since dir is sure @@ -143,14 +150,17 @@ func OnIncomingMessagesRequest( if streamToken, err = types.NewStreamTokenFromString(fromQuery); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + JSON: spec.InvalidParam("Invalid from parameter: " + err.Error()), } } else { fromStream = &streamToken from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } @@ -166,13 +176,16 @@ func OnIncomingMessagesRequest( if streamToken, err = types.NewStreamTokenFromString(toQuery); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), + JSON: spec.InvalidParam("Invalid to parameter: " + err.Error()), } } else { to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } } @@ -195,12 +208,12 @@ func OnIncomingMessagesRequest( if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()), + JSON: spec.MissingParam("Bad room ID: " + err.Error()), } } // If the user already left the room, grep events from before that - if membershipResp.Membership == gomatrixserverlib.Leave { + if membershipResp.Membership == spec.Leave { var token types.TopologyToken token, err = snapshot.EventPositionInTopology(req.Context(), membershipResp.EventID) if err != nil { @@ -228,10 +241,13 @@ func OnIncomingMessagesRequest( device: device, } - clientEvents, start, end, err := mReq.retrieveEvents() + clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } util.GetLogger(req.Context()).WithFields(logrus.Fields{ @@ -252,9 +268,14 @@ func OnIncomingMessagesRequest( membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - res.State = append(res.State, synctypes.HeaderedToClientEvents(membershipEvents, synctypes.FormatAll)...) + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + })...) } // If we didn't return any events, set the end to an empty string, so it will be omitted @@ -275,9 +296,13 @@ func OnIncomingMessagesRequest( } func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { + fullUserID, err := spec.NewUserID(userID, true) + if err != nil { + return resp, err + } req := api.QueryMembershipForUserRequest{ RoomID: roomID, - UserID: userID, + UserID: *fullUserID, } if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { return api.QueryMembershipForUserResponse{}, err @@ -291,7 +316,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api. // homeserver in the room for older events. // Returns an error if there was an issue talking to the database or with the // remote homeserver. -func (r *messagesReq) retrieveEvents() ( +func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) ( clientEvents []synctypes.ClientEvent, start, end types.TopologyToken, err error, ) { @@ -303,7 +328,7 @@ func (r *messagesReq) retrieveEvents() ( return []synctypes.ClientEvent{}, emptyToken, emptyToken, err } - var events []*gomatrixserverlib.HeaderedEvent + var events []*rstypes.HeaderedEvent util.GetLogger(r.ctx).WithFields(logrus.Fields{ "start": r.from, "end": r.to, @@ -342,8 +367,8 @@ func (r *messagesReq) retrieveEvents() ( // Sort the events to ensure we send them in the right order. if r.backwardOrdering { // This reverses the array from old->new to new->old - reversed := func(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) + reversed := func(in []*rstypes.HeaderedEvent) []*rstypes.HeaderedEvent { + out := make([]*rstypes.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { out[i] = in[len(in)-i-1] } @@ -364,13 +389,15 @@ func (r *messagesReq) retrieveEvents() ( "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.HeaderedToClientEvents(filteredEvents, synctypes.FormatAll), start, end, err + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), start, end, err } -func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { +func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { if r.backwardOrdering { start = *r.from - if events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { + if events[len(events)-1].Type() == spec.MRoomCreate { // NOTSPEC: We've hit the beginning of the room so there's really nowhere // else to go. This seems to fix Element iOS from looping on /messages endlessly. end = types.TopologyToken{} @@ -406,7 +433,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st // Returns an error if there was an issue talking with the database or // backfilling. func (r *messagesReq) handleEmptyEventsSlice() ( - events []*gomatrixserverlib.HeaderedEvent, err error, + events []*rstypes.HeaderedEvent, err error, ) { backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID) @@ -420,7 +447,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( } else { // If not, it means the slice was empty because we reached the room's // creation, so return an empty slice. - events = []*gomatrixserverlib.HeaderedEvent{} + events = []*rstypes.HeaderedEvent{} } return @@ -432,7 +459,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // through backfilling if needed. // Returns an error if there was an issue while backfilling. func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent) ( - events []*gomatrixserverlib.HeaderedEvent, err error, + events []*rstypes.HeaderedEvent, err error, ) { // Check if we have enough events. isSetLargeEnough := len(streamEvents) >= r.filter.Limit @@ -460,7 +487,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // Backfill is needed if we've reached a backward extremity and need more // events. It's only needed if the direction is backward. if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { - var pdus []*gomatrixserverlib.HeaderedEvent + var pdus []*rstypes.HeaderedEvent // Only ask the remote server for enough events to reach the limit. pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents)) if err != nil { @@ -472,13 +499,13 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Append the events ve previously retrieved locally. - events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) + events = append(events, r.snapshot.StreamEventsToEvents(r.ctx, nil, streamEvents, r.rsAPI)...) sort.Sort(eventsByDepth(events)) return } -type eventsByDepth []*gomatrixserverlib.HeaderedEvent +type eventsByDepth []*rstypes.HeaderedEvent func (e eventsByDepth) Len() int { return len(e) @@ -499,7 +526,7 @@ func (e eventsByDepth) Less(i, j int) bool { // event, or if there is no remote homeserver to contact. // Returns an error if there was an issue with retrieving the list of servers in // the room or sending the request. -func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]*rstypes.HeaderedEvent, error) { var res api.PerformBackfillResponse err := r.rsAPI.PerformBackfill(context.Background(), &api.PerformBackfillRequest{ RoomID: roomID, @@ -532,7 +559,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] _, err = r.db.WriteEvent( context.Background(), res.Events[i], - []*gomatrixserverlib.HeaderedEvent{}, + []*rstypes.HeaderedEvent{}, []string{}, []string{}, nil, true, diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 79533883f5..17933b2fb0 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -18,18 +18,18 @@ import ( "net/http" "strconv" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type RelationsResponse struct { @@ -73,14 +73,17 @@ func Relations( if dir != "b" && dir != "f" { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + JSON: spec.MissingParam("Bad or missing dir query parameter (should be either 'b' or 'f')"), } } snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { logrus.WithError(err).Error("Failed to get snapshot for relations") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -96,7 +99,7 @@ func Relations( return util.ErrorResponse(err) } - headeredEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) + headeredEvents := make([]*rstypes.HeaderedEvent, 0, len(events)) for _, event := range events { headeredEvents = append(headeredEvents, event.HeaderedEvent) } @@ -107,13 +110,32 @@ func Relations( return util.ErrorResponse(err) } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.ErrorResponse(err) + } + // Convert the events into client events, and optionally filter based on the event // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.Event, synctypes.FormatAll), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk), ) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index b1283247b5..8542c0b73e 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,10 +18,9 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" @@ -96,7 +95,7 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("context", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -112,7 +111,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relations", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -126,7 +125,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -140,7 +139,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAuthAPI("relation_type_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -158,18 +157,21 @@ func Setup( if !cfg.Fulltext.Enabled { return util.JSONResponse{ Code: http.StatusNotImplemented, - JSON: jsonerror.Unknown("Search has been disabled by the server administrator."), + JSON: spec.Unknown("Search has been disabled by the server administrator."), } } var nextBatch *string if err := req.ParseForm(); err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if req.Form.Has("next_batch") { nb := req.FormValue("next_batch") nextBatch = &nb } - return Search(req, device, syncDB, fts, nextBatch) + return Search(req, device, syncDB, fts, nextBatch, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -201,7 +203,7 @@ func Setup( return util.ErrorResponse(err) } at := req.URL.Query().Get("at") - membership := gomatrixserverlib.Join + membership := spec.Join return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, true, &membership, nil, at) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 15cb2f9b8f..d892b604a5 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -23,21 +23,23 @@ import ( "github.com/blevesearch/bleve/v2/search" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/api" ) // nolint:gocyclo -func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse { start := time.Now() var ( searchReq SearchRequest @@ -54,7 +56,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if from != nil && *from != "" { nextBatch, err = strconv.Atoi(*from) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } } @@ -64,7 +69,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -72,12 +80,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts // only search rooms the user is actually joined to joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join") if err != nil { - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } if len(joinedRooms) == 0 { return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.NotFound("User not joined to any rooms."), + JSON: spec.NotFound("User not joined to any rooms."), } } joinedRoomsMap := make(map[string]struct{}, len(joinedRooms)) @@ -98,7 +109,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts if len(rooms) == 0 { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Unknown("User not allowed to search in this room(s)."), + JSON: spec.Unknown("User not allowed to search in this room(s)."), } } @@ -114,7 +125,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ) if err != nil { logrus.WithError(err).Error("failed to search fulltext") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } logrus.Debugf("Search took %s", result.Took) @@ -154,7 +168,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts evs, err := syncDB.Events(ctx, wantEvents) if err != nil { logrus.WithError(err).Error("failed to get events from database") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } groups := make(map[string]RoomResult) @@ -172,21 +189,38 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq) if err != nil { logrus.WithError(err).Error("failed to get context events") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) if err != nil { logrus.WithError(err).Error("failed to get start/end") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - profile, ok := knownUsersProfiles[event.Sender()] + validRoomID, roomErr := spec.NewRoomID(ev.RoomID()) + if err != nil { + logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile") + continue + } + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) + if queryErr != nil { + logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile") + continue + } + + profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) - if err != nil { - logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, string(ev.SenderID())) + if stateErr != nil { + logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -196,21 +230,44 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, } - knownUsersProfiles[event.Sender()] = profile + knownUsersProfiles[userID.String()] = profile } - profileInfos[ev.Sender()] = profile + profileInfos[userID.String()] = profile } + sender := spec.UserID{} + validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + if err != nil { + logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile") + continue + } + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } results = append(results, Result{ Context: SearchContextResponse{ - Start: startToken.String(), - End: endToken.String(), - EventsAfter: synctypes.HeaderedToClientEvents(eventsAfter, synctypes.FormatSync), - EventsBefore: synctypes.HeaderedToClientEvents(eventsBefore, synctypes.FormatSync), - ProfileInfo: profileInfos, + Start: startToken.String(), + End: endToken.String(), + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }), + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }), + ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.HeaderedToClientEvent(event, synctypes.FormatAll), + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) @@ -220,9 +277,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to get current state") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } - stateForRooms[event.RoomID()] = synctypes.HeaderedToClientEvents(state, synctypes.FormatSync) + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) } } @@ -262,10 +324,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts func contextEvents( ctx context.Context, snapshot storage.DatabaseTransaction, - event *gomatrixserverlib.HeaderedEvent, + event *types.HeaderedEvent, roomFilter *synctypes.RoomEventFilter, searchReq SearchRequest, -) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { +) ([]*types.HeaderedEvent, []*types.HeaderedEvent, error) { id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID()) if err != nil { logrus.WithError(err).Error("failed to query context event") diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 3e3fcdbd07..ab47da1297 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -9,6 +10,8 @@ import ( "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -16,9 +19,16 @@ import ( "github.com/matrix-org/dendrite/test/testrig" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) +type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } + +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestSearch(t *testing.T) { alice := test.NewUser(t) aliceDevice := userapi.Device{UserID: alice.ID} @@ -214,12 +224,13 @@ func TestSearch(t *testing.T) { // store the events in the database var sp types.StreamPosition for _, x := range room.Events() { - var stateEvents []*gomatrixserverlib.HeaderedEvent + var stateEvents []*rstypes.HeaderedEvent var stateEventIDs []string - if x.Type() == gomatrixserverlib.MRoomMember { + if x.Type() == spec.MRoomMember { stateEvents = append(stateEvents, x) stateEventIDs = append(stateEventIDs, x.EventID()) } + x.StateKeyResolved = x.StateKey() sp, err = db.WriteEvent(processCtx.Context(), x, stateEvents, stateEventIDs, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared) assert.NoError(t, err) if x.Type() != "m.room.message" { @@ -245,7 +256,7 @@ func TestSearch(t *testing.T) { assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/", reqBody) - res := Search(req, tc.device, db, fts, tc.from) + res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{}) if !tc.wantOK && !res.Is2xx() { return } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a53822f7eb..09ce02396d 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,10 +18,12 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -41,16 +43,16 @@ type DatabaseTransaction interface { MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) - CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) - GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) - GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) + CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) + GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error) + GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) - GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) + GetBackwardTopologyPos(ctx context.Context, events []*rstypes.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) - InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) + InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*rstypes.HeaderedEvent, map[string]*rstypes.HeaderedEvent, types.StreamPosition, error) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. @@ -64,15 +66,15 @@ type DatabaseTransaction interface { // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. - Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + Events(ctx context.Context, eventIDs []string) ([]*rstypes.HeaderedEvent, error) // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*rstypes.HeaderedEvent, error) // GetStateEventsForRoom fetches the state events for a given room. // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. - GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) + GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter) (stateEvents []*rstypes.HeaderedEvent, err error) // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes @@ -88,15 +90,15 @@ type DatabaseTransaction interface { // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. - StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent + StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // relevant events within the given ranges for the supplied user ID and device ID. SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) - SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) - SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) - SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + SelectContextEvent(ctx context.Context, roomID, eventID string) (int, rstypes.HeaderedEvent, error) + SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*rstypes.HeaderedEvent, error) + SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*rstypes.HeaderedEvent, error) StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found @@ -123,11 +125,11 @@ type Database interface { // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. - Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + Events(ctx context.Context, eventIDs []string) ([]*rstypes.HeaderedEvent, error) // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // when generating the sync stream position for this event. Returns the sync stream position for the inserted event. // Returns an error if there was a problem inserting this event. - WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent, + WriteEvent(ctx context.Context, ev *rstypes.HeaderedEvent, addStateEvents []*rstypes.HeaderedEvent, addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility, ) (types.StreamPosition, error) @@ -146,7 +148,7 @@ type Database interface { // AddInviteEvent stores a new invite event for a user. // If the invite was successfully stored this returns the stream ID it was stored at. // Returns an error if there was a problem communicating with the database. - AddInviteEvent(ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) + AddInviteEvent(ctx context.Context, inviteEvent *rstypes.HeaderedEvent) (types.StreamPosition, error) // RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite. // Returns an error if there was a problem communicating with the database. RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error) @@ -173,12 +175,12 @@ type Database interface { // goes wrong. PutFilter(ctx context.Context, localpart string, filter *synctypes.Filter) (string, error) // RedactEvent wipes an event in the database and sets the unsigned.redacted_because key to the redaction event - RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error + RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error // StoreReceipt stores new receipt events - StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error - ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) - UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error + ReIndex(ctx context.Context, limit, afterID int64) (map[int64]rstypes.HeaderedEvent, error) + UpdateRelations(ctx context.Context, event *rstypes.HeaderedEvent) error RedactRelations(ctx context.Context, roomID, redactedEventID string) error SelectMemberships( ctx context.Context, @@ -189,7 +191,7 @@ type Database interface { type Presence interface { GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) - UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) + UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS spec.Timestamp, fromSync bool) (types.StreamPosition, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter synctypes.EventFilter) (map[string]*types.PresenceInternal, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) ExpirePresence(ctx context.Context) ([]types.PresenceNotify, error) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index b05477585e..112fa9d4a9 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -24,11 +24,13 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const currentRoomStateSchema = ` @@ -273,14 +275,14 @@ func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilter *synctypes.StateFilter, excludeEventIDs []string, -) ([]*gomatrixserverlib.HeaderedEvent, error) { +) ([]*rstypes.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) senders, notSenders := getSendersStateFilterFilter(stateFilter) // We're going to query members later, so remove them from this request if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { - notTypes := &[]string{gomatrixserverlib.MRoomMember} + notTypes := &[]string{spec.MRoomMember} if stateFilter.NotTypes != nil { - *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + *stateFilter.NotTypes = append(*stateFilter.NotTypes, spec.MRoomMember) } else { stateFilter.NotTypes = notTypes } @@ -319,7 +321,7 @@ func (s *currentRoomStateStatements) DeleteRoomStateForRoom( func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, - event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, + event *rstypes.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { // Parse content as JSON and search for an "url" key containsURL := false @@ -341,9 +343,9 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.UserID.String(), containsURL, - *event.StateKey(), + *event.StateKeyResolved, headeredJSON, membership, addedAt, @@ -377,8 +379,8 @@ func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, er return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } @@ -393,8 +395,8 @@ func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, er return events, nil } -func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { - result := []*gomatrixserverlib.HeaderedEvent{} +func rowsToEvents(rows *sql.Rows) ([]*rstypes.HeaderedEvent, error) { + result := []*rstypes.HeaderedEvent{} for rows.Next() { var eventID string var eventBytes []byte @@ -402,8 +404,8 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } result = append(result, &ev) @@ -413,7 +415,7 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func (s *currentRoomStateStatements) SelectStateEvent( ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { +) (*rstypes.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) @@ -423,7 +425,7 @@ func (s *currentRoomStateStatements) SelectStateEvent( if err != nil { return nil, err } - var ev gomatrixserverlib.HeaderedEvent + var ev rstypes.HeaderedEvent if err = json.Unmarshal(res, &ev); err != nil { return nil, err } diff --git a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go index d68ed8d5f2..37660ee9d6 100644 --- a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go +++ b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -79,7 +80,7 @@ func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gom defer rows.Close() // nolint: errcheck var eventBytes []byte var roomID string - var event gomatrixserverlib.HeaderedEvent + var event types.HeaderedEvent var hisVis gomatrixserverlib.HistoryVisibility historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility) for rows.Next() { diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 151bffa5d5..7b8d2d7336 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -22,9 +22,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const inviteEventsSchema = ` @@ -89,7 +89,7 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteEventsStatements) InsertInviteEvent( - ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, + ctx context.Context, txn *sql.Tx, inviteEvent *rstypes.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { var headeredJSON []byte headeredJSON, err = json.Marshal(inviteEvent) @@ -101,7 +101,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( ctx, inviteEvent.RoomID(), inviteEvent.EventID(), - *inviteEvent.StateKey(), + inviteEvent.UserID.String(), headeredJSON, ).Scan(&streamPos) return @@ -119,7 +119,7 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { +) (map[string]*rstypes.HeaderedEvent, map[string]*rstypes.HeaderedEvent, types.StreamPosition, error) { var lastPos types.StreamPosition stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) @@ -127,8 +127,8 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( return nil, nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") - result := map[string]*gomatrixserverlib.HeaderedEvent{} - retired := map[string]*gomatrixserverlib.HeaderedEvent{} + result := map[string]*rstypes.HeaderedEvent{} + retired := map[string]*rstypes.HeaderedEvent{} for rows.Next() { var ( id types.StreamPosition @@ -151,7 +151,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( continue } - var event *gomatrixserverlib.HeaderedEvent + var event *rstypes.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { return nil, nil, lastPos, err } diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 47833893aa..09b47432b7 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,9 +19,8 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -100,7 +99,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { } func (s *membershipsStatements) UpsertMembership( - ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, + ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, streamPos, topologicalPos types.StreamPosition, ) error { membership, err := event.Membership() @@ -110,7 +109,7 @@ func (s *membershipsStatements) UpsertMembership( _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, event.RoomID(), - *event.StateKey(), + event.StateKeyResolved, membership, event.EventID(), streamPos, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3900ac3ae4..b58cf59f04 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -136,15 +137,6 @@ FROM room_ids, ) AS x ` -const selectEarlyEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" + - " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + - " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + - " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + - " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + - " ORDER BY id ASC LIMIT $8" - const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" @@ -206,7 +198,6 @@ type outputRoomEventsStatements struct { selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt selectStateInRangeFilteredStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt updateEventJSONStmt *sql.Stmt @@ -262,7 +253,6 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectRecentEventsStmt, selectRecentEventsSQL}, {&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL}, - {&s.selectEarlyEventsStmt, selectEarlyEventsSQL}, {&s.selectStateInRangeFilteredStmt, selectStateInRangeFilteredSQL}, {&s.selectStateInRangeStmt, selectStateInRangeSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, @@ -275,7 +265,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err @@ -340,8 +330,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange( } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, nil, err } needSet := stateNeeded[ev.RoomID()] @@ -386,7 +376,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( // of the inserted event. func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, - event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, + event *rstypes.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility, ) (streamPos types.StreamPosition, err error) { var txnID *string @@ -417,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.UserID.String(), containsURL, pq.StringArray(addState), pq.StringArray(removeState), @@ -476,8 +466,8 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } @@ -530,39 +520,6 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( return result, rows.Err() } -// selectEarlyEvents returns the earliest events in the given room, starting -// from a given position, up to a maximum of 'limit'. -func (s *outputRoomEventsStatements) SelectEarlyEvents( - ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *synctypes.RoomEventFilter, -) ([]types.StreamEvent, error) { - senders, notSenders := getSendersRoomEventFilter(eventFilter) - stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) - rows, err := stmt.QueryContext( - ctx, roomID, r.Low(), r.High(), - pq.StringArray(senders), - pq.StringArray(notSenders), - pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), - eventFilter.Limit, - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, err - } - // The events need to be returned from oldest to latest, which isn't - // necessarily the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - return events, nil -} - // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( @@ -621,7 +578,7 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom( return err } -func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (id int, evt gomatrixserverlib.HeaderedEvent, err error) { +func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (id int, evt rstypes.HeaderedEvent, err error) { row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID) var eventAsString string @@ -639,7 +596,7 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn func (s *outputRoomEventsStatements) SelectContextBeforeEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, -) (evts []*gomatrixserverlib.HeaderedEvent, err error) { +) (evts []*rstypes.HeaderedEvent, err error) { senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( ctx, roomID, id, filter.Limit, @@ -656,7 +613,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( for rows.Next() { var ( eventBytes []byte - evt *gomatrixserverlib.HeaderedEvent + evt *rstypes.HeaderedEvent historyVisibility gomatrixserverlib.HistoryVisibility ) if err = rows.Scan(&eventBytes, &historyVisibility); err != nil { @@ -674,7 +631,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, -) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { +) (lastID int, evts []*rstypes.HeaderedEvent, err error) { senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( ctx, roomID, id, filter.Limit, @@ -691,7 +648,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( for rows.Next() { var ( eventBytes []byte - evt *gomatrixserverlib.HeaderedEvent + evt *rstypes.HeaderedEvent historyVisibility gomatrixserverlib.HistoryVisibility ) if err = rows.Scan(&lastID, &eventBytes, &historyVisibility); err != nil { @@ -724,8 +681,8 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } @@ -753,7 +710,7 @@ func (s *outputRoomEventsStatements) PurgeEvents( return err } -func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { +func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]rstypes.HeaderedEvent, error) { rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) if err != nil { return nil, err @@ -762,14 +719,14 @@ func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, l var eventID string var id int64 - result := make(map[int64]gomatrixserverlib.HeaderedEvent) + result := make(map[int64]rstypes.HeaderedEvent) for rows.Next() { - var ev gomatrixserverlib.HeaderedEvent + var ev rstypes.HeaderedEvent var eventBytes []byte if err = rows.Scan(&id, &eventID, &eventBytes); err != nil { return nil, err } - if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + if err = json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } result[id] = ev diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 2382fca5c5..7140a92fc1 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,10 +18,9 @@ import ( "context" "database/sql" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -105,7 +104,7 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { // InsertEventInTopology inserts the given event in the room's topology, based // on the event's depth. func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( - ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, + ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition, ) (topoPos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).QueryRowContext( ctx, event.EventID(), event.Depth(), event.RoomID(), pos, diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index e48718007b..37ee3faf32 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -20,7 +20,7 @@ import ( "time" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -127,7 +127,7 @@ func (p *presenceStatements) UpsertPresence( userID string, statusMsg *string, presence types.Presence, - lastActiveTS gomatrixserverlib.Timestamp, + lastActiveTS spec.Timestamp, fromSync bool, ) (pos types.StreamPosition, err error) { if fromSync { @@ -179,7 +179,7 @@ func (p *presenceStatements) GetPresenceAfter( ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + afterTS := spec.AsTimestamp(time.Now().Add(time.Minute * -5)) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 0fcbebfcbd..9ab8eece0d 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const receiptsSchema = ` @@ -98,7 +98,7 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { }.Prepare(db) } -func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) { stmt := sqlutil.TxStmt(txn, r.upsertReceipt) err = stmt.QueryRowContext(ctx, roomId, receiptType, userId, eventId, timestamp).Scan(&pos) return diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index e075833ed5..a3ffc8f786 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -23,7 +23,9 @@ import ( "github.com/tidwall/gjson" + rstypes "github.com/matrix-org/dendrite/roomserver/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -93,7 +95,7 @@ func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransac }, nil } -func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*rstypes.HeaderedEvent, error) { streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) if err != nil { return nil, err @@ -101,14 +103,55 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse // We don't include a device here as we only include transaction IDs in // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil +} + +func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent { + out := make([]*rstypes.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[i].HeaderedEvent + if device != nil && in[i].TransactionID != nil { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + roomID, err := spec.NewRoomID(in[i].RoomID()) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Room ID is invalid") + continue + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].TransactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } + } + return out } // AddInviteEvent stores a new invite event for a user. // If the invite was successfully stored this returns the stream ID it was stored at. // Returns an error if there was a problem communicating with the database. func (d *Database) AddInviteEvent( - ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, + ctx context.Context, inviteEvent *rstypes.HeaderedEvent, ) (sp types.StreamPosition, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) @@ -192,31 +235,11 @@ func (d *Database) UpsertAccountData( return } -func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. // This function should always be called within a sqlutil.Writer for safety in SQLite. -func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { +func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *rstypes.HeaderedEvent) error { if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { return err } @@ -249,8 +272,8 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e func (d *Database) WriteEvent( ctx context.Context, - ev *gomatrixserverlib.HeaderedEvent, - addStateEvents []*gomatrixserverlib.HeaderedEvent, + ev *rstypes.HeaderedEvent, + addStateEvents []*rstypes.HeaderedEvent, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility, @@ -291,7 +314,7 @@ func (d *Database) WriteEvent( func (d *Database) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, - addedEvents []*gomatrixserverlib.HeaderedEvent, + addedEvents []*rstypes.HeaderedEvent, pduPosition types.StreamPosition, topoPosition types.StreamPosition, ) error { @@ -352,7 +375,7 @@ func (d *Database) PutFilter( return filterID, err } -func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error { +func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error { redactedEvents, err := d.Events(ctx, []string{redactedEventID}) if err != nil { return err @@ -361,13 +384,13 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") return nil } - eventToRedact := redactedEvents[0].Unwrap() - redactionEvent := redactedBecause.Unwrap() - if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { + eventToRedact := redactedEvents[0].PDU + redactionEvent := redactedBecause.PDU + if err = eventutil.RedactEvent(ctx, redactionEvent, eventToRedact, querier); err != nil { return err } - newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) + newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact} err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) }) @@ -502,8 +525,24 @@ func (d *Database) CleanSendToDeviceUpdates( // getMembershipFromEvent returns the value of content.membership iff the event is a state event // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) { - if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { +func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) { + if ev.StateKey() == nil || *ev.StateKey() == "" { + return "", "" + } + fullUser, err := spec.NewUserID(userID, true) + if err != nil { + return "", "" + } + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return "", "" + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser) + if err != nil { + return "", "" + } + + if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) { return "", "" } membership, err := ev.Membership() @@ -515,7 +554,7 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, } // StoreReceipt stores user receipts -func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { +func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) return err @@ -531,14 +570,14 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI return } -func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, rstypes.HeaderedEvent, error) { return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } -func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*rstypes.HeaderedEvent, error) { return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) } -func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*rstypes.HeaderedEvent, error) { return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) } @@ -552,7 +591,7 @@ func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, igno }) } -func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS spec.Timestamp, fromSync bool) (types.StreamPosition, error) { var pos types.StreamPosition var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -570,17 +609,17 @@ func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID s return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) } -func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]rstypes.HeaderedEvent, error) { return d.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ - gomatrixserverlib.MRoomName, - gomatrixserverlib.MRoomTopic, + spec.MRoomName, + spec.MRoomTopic, "m.room.message", }) } -func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (d *Database) UpdateRelations(ctx context.Context, event *rstypes.HeaderedEvent) error { // No need to unmarshal if the event is a redaction - if event.Type() == gomatrixserverlib.MRoomRedaction { + if event.Type() == spec.MRoomRedaction { return nil } var content gomatrixserverlib.RelationContent @@ -635,7 +674,7 @@ func (s *Database) UpdateLastActive(ctx context.Context, userId string, lastActi return s.Presence.UpdateLastActive(ctx, userId, lastActiveTs) } -func (d *Database) UpdateMultiRoomVisibility(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (d *Database) UpdateMultiRoomVisibility(ctx context.Context, event *rstypes.HeaderedEvent) error { var mrdEv mrd.StateEvent err := json.Unmarshal(event.Content(), &mrdEv) if err != nil { @@ -643,7 +682,7 @@ func (d *Database) UpdateMultiRoomVisibility(ctx context.Context, event *gomatri } if mrdEv.Hidden { err = d.MultiRoomQ.DeleteMultiRoomVisibility(ctx, mrd.DeleteMultiRoomVisibilityParams{ - UserID: event.Sender(), + UserID: string(event.SenderID()), Type: event.Type(), RoomID: event.RoomID(), }) @@ -653,7 +692,7 @@ func (d *Database) UpdateMultiRoomVisibility(ctx context.Context, event *gomatri } if mrdEv.ExpireTs > 0 { err = d.MultiRoomQ.InsertMultiRoomVisibility(ctx, mrd.InsertMultiRoomVisibilityParams{ - UserID: event.Sender(), + UserID: string(event.SenderID()), Type: event.Type(), RoomID: event.RoomID(), ExpireTs: mrdEv.ExpireTs, diff --git a/syncapi/storage/shared/storage_consumer_test.go b/syncapi/storage/shared/storage_consumer_test.go new file mode 100644 index 0000000000..54a2ee88d1 --- /dev/null +++ b/syncapi/storage/shared/storage_consumer_test.go @@ -0,0 +1,103 @@ +package shared_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func newSyncDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + syncDB, _, err := storage.NewSyncServerDatasource(processCtx.Context(), cm, &cfg.SyncAPI.Database) + if err != nil { + t.Fatalf("failed to create sync DB: %s", err) + } + + return syncDB, closeDB +} + +func TestFilterTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := newSyncDB(t, dbType) + defer closeDB() + + // initially create a filter + filter := &synctypes.Filter{} + filterID, err := tab.PutFilter(context.Background(), "alice", filter) + if err != nil { + t.Fatal(err) + } + + // create the same filter again, we should receive the existing filter + secondFilterID, err := tab.PutFilter(context.Background(), "alice", filter) + if err != nil { + t.Fatal(err) + } + + if secondFilterID != filterID { + t.Fatalf("expected second filter to be the same as the first: %s vs %s", filterID, secondFilterID) + } + + // query the filter again + targetFilter := &synctypes.Filter{} + if err = tab.GetFilter(context.Background(), targetFilter, "alice", filterID); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(filter, targetFilter) { + t.Fatalf("%#v vs %#v", filter, targetFilter) + } + + // query non-existent filter + if err = tab.GetFilter(context.Background(), targetFilter, "bob", filterID); err == nil { + t.Fatalf("expected filter to not exist, but it does exist: %v", targetFilter) + } + }) +} + +func TestIgnores(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + syncDB, closeDB := newSyncDB(t, dbType) + defer closeDB() + + tab, err := syncDB.NewDatabaseTransaction(context.Background()) + if err != nil { + t.Fatal(err) + } + defer tab.Rollback() // nolint: errcheck + + ignoredUsers := &types.IgnoredUsers{List: map[string]interface{}{ + bob.ID: "", + }} + if err = tab.UpdateIgnoresForUser(context.Background(), alice.ID, ignoredUsers); err != nil { + t.Fatal(err) + } + + gotIgnoredUsers, err := tab.IgnoresForUser(context.Background(), alice.ID) + if err != nil { + t.Fatal(err) + } + + // verify the ignored users matches those we stored + if !reflect.DeepEqual(gotIgnoredUsers, ignoredUsers) { + t.Fatalf("%#v vs %#v", gotIgnoredUsers, ignoredUsers) + } + + // Bob doesn't have any ignored users, so should receive sql.ErrNoRows + if _, err = tab.IgnoresForUser(context.Background(), bob.ID); err == nil { + t.Fatalf("expected an error but got none") + } + }) +} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index aa0eb06764..cd4a8bd6c0 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -6,10 +6,12 @@ import ( "fmt" "math" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -83,7 +85,7 @@ func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.C return types.StreamPosition(id), nil } -func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) { return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs) } @@ -98,11 +100,11 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) { summary := &types.Summary{Heroes: []string{}} - joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join) + joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, spec.Join) if err != nil { return summary, err } - inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite) + inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, spec.Invite) if err != nil { return summary, err } @@ -111,7 +113,7 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID // Get the room name and canonical alias, if any filter := synctypes.DefaultStateFilter() - filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias} + filterTypes := []string{spec.MRoomName, spec.MRoomCanonicalAlias} filterRooms := []string{roomID} filter.Types = &filterTypes @@ -123,11 +125,11 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID for _, ev := range evs { switch ev.Type() { - case gomatrixserverlib.MRoomName: + case spec.MRoomName: if gjson.GetBytes(ev.Content(), "name").Str != "" { return summary, nil } - case gomatrixserverlib.MRoomCanonicalAlias: + case spec.MRoomCanonicalAlias: if gjson.GetBytes(ev.Content(), "alias").Str != "" { return summary, nil } @@ -135,14 +137,14 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID } // If there's no room name or canonical alias, get the room heroes, excluding the user - heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite}) + heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{spec.Join, spec.Invite}) if err != nil { return summary, err } // "When no joined or invited members are available, this should consist of the banned and left users" if len(heroes) == 0 { - heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban}) + heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{spec.Leave, spec.Ban}) if err != nil { return summary, err } @@ -160,7 +162,7 @@ func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID st return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) } -func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { +func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*rstypes.HeaderedEvent, map[string]*rstypes.HeaderedEvent, types.StreamPosition, error) { return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r) } @@ -177,7 +179,7 @@ func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []s // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. -func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { +func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*rstypes.HeaderedEvent, error) { streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false) if err != nil { return nil, err @@ -185,7 +187,7 @@ func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([] // We don't include a device here as we only include transaction IDs in // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil } func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { @@ -206,13 +208,13 @@ func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, ot func (d *DatabaseTransaction) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { +) (*rstypes.HeaderedEvent, error) { return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey) } func (d *DatabaseTransaction) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *synctypes.StateFilter, -) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { +) (stateEvents []*rstypes.HeaderedEvent, err error) { stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) return } @@ -301,7 +303,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition( // oldest event in the room's topology. func (d *DatabaseTransaction) GetBackwardTopologyPos( ctx context.Context, - events []*gomatrixserverlib.HeaderedEvent, + events []*rstypes.HeaderedEvent, ) (types.TopologyToken, error) { zeroToken := types.TopologyToken{} if len(events) == 0 { @@ -324,7 +326,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *synctypes.StateFilter, + stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI, ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -349,7 +351,7 @@ func (d *DatabaseTransaction) GetStateDeltas( joinedRoomIDs := make([]string, 0, len(memberships)) for roomID, membership := range memberships { allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { + if membership == spec.Join { joinedRoomIDs = append(joinedRoomIDs, roomID) } } @@ -415,8 +417,8 @@ func (d *DatabaseTransaction) GetStateDeltas( } if !peek.Deleted { deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Peek, - StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), + Membership: spec.Peek, + StateEvents: d.StreamEventsToEvents(ctx, device, state[peek.RoomID], rsAPI), RoomID: peek.RoomID, }) } @@ -428,12 +430,12 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, ev := range stateStreamEvents { // Look for our membership in the state events and skip over any // membership events that are not related to us. - membership, prevMembership := getMembershipFromEvent(ev.Event, userID) + membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI) if membership == "" { continue } - if membership == gomatrixserverlib.Join { + if membership == spec.Join { // If our membership is now join but the previous membership wasn't // then this is a "join transition", so we'll insert this room. if prevMembership != membership { @@ -461,7 +463,7 @@ func (d *DatabaseTransaction) GetStateDeltas( deltas = append(deltas, types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[roomID], rsAPI), RoomID: roomID, }) break @@ -472,8 +474,8 @@ func (d *DatabaseTransaction) GetStateDeltas( // join transitions above. for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), + Membership: spec.Join, + StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[joinedRoomID], rsAPI), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], }) @@ -489,7 +491,7 @@ func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *synctypes.StateFilter, + stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI, ) ([]types.StateDelta, []string, error) { // Look up all memberships for the user. We only care about rooms that a // user has ever interacted with — joined to, kicked/banned from, left. @@ -505,7 +507,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( joinedRoomIDs := make([]string, 0, len(memberships)) for roomID, membership := range memberships { allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { + if membership == spec.Join { joinedRoomIDs = append(joinedRoomIDs, roomID) } } @@ -529,8 +531,8 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( return nil, nil, stateErr } deltas[peek.RoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Peek, - StateEvents: d.StreamEventsToEvents(device, s), + Membership: spec.Peek, + StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI), RoomID: peek.RoomID, } } @@ -554,12 +556,12 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { - if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. + if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" { + if membership != spec.Join { // We've already added full state for all joined rooms above. deltas[roomID] = types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + StateEvents: d.StreamEventsToEvents(ctx, device, stateStreamEvents, rsAPI), RoomID: roomID, } } @@ -579,8 +581,8 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( return nil, nil, stateErr } deltas[joinedRoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, s), + Membership: spec.Join, + StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI), RoomID: joinedRoomID, } } @@ -636,7 +638,7 @@ func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []str func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { roomIDs := make([]string, 0, len(rooms)) for roomID, membership := range rooms { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } roomIDs = append(roomIDs, roomID) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index b49c2f7015..132bd80c8f 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -130,7 +130,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( if pos == 0 { pos = r.High() } - return data, pos, nil + return data, pos, rows.Err() } func (s *accountDataStatements) SelectMaxAccountDataID( diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index c681933d5d..3bd19b3676 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -25,11 +25,13 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const currentRoomStateSchema = ` @@ -267,12 +269,12 @@ func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilter *synctypes.StateFilter, excludeEventIDs []string, -) ([]*gomatrixserverlib.HeaderedEvent, error) { +) ([]*rstypes.HeaderedEvent, error) { // We're going to query members later, so remove them from this request if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { - notTypes := &[]string{gomatrixserverlib.MRoomMember} + notTypes := &[]string{spec.MRoomMember} if stateFilter.NotTypes != nil { - *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + *stateFilter.NotTypes = append(*stateFilter.NotTypes, spec.MRoomMember) } else { stateFilter.NotTypes = notTypes } @@ -318,7 +320,7 @@ func (s *currentRoomStateStatements) DeleteRoomStateForRoom( func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, - event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, + event *rstypes.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { // Parse content as JSON and search for an "url" key containsURL := false @@ -340,9 +342,9 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.UserID.String(), containsURL, - *event.StateKey(), + *event.StateKeyResolved, headeredJSON, membership, addedAt, @@ -404,8 +406,8 @@ func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, er return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } @@ -420,8 +422,8 @@ func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, er return events, nil } -func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { - result := []*gomatrixserverlib.HeaderedEvent{} +func rowsToEvents(rows *sql.Rows) ([]*rstypes.HeaderedEvent, error) { + result := []*rstypes.HeaderedEvent{} for rows.Next() { var eventID string var eventBytes []byte @@ -429,8 +431,8 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } result = append(result, &ev) @@ -440,7 +442,7 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func (s *currentRoomStateStatements) SelectStateEvent( ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { +) (*rstypes.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) @@ -450,7 +452,7 @@ func (s *currentRoomStateStatements) SelectStateEvent( if err != nil { return nil, err } - var ev gomatrixserverlib.HeaderedEvent + var ev rstypes.HeaderedEvent if err = json.Unmarshal(res, &ev); err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go index d23f07566b..f7ce6531ec 100644 --- a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go +++ b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -91,7 +92,7 @@ func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gom defer rows.Close() // nolint: errcheck var eventBytes []byte var roomID string - var event gomatrixserverlib.HeaderedEvent + var event types.HeaderedEvent var hisVis gomatrixserverlib.HistoryVisibility historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility) for rows.Next() { diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 19450099af..7e0d895f12 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -22,9 +22,9 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const inviteEventsSchema = ` @@ -89,7 +89,7 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv } func (s *inviteEventsStatements) InsertInviteEvent( - ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, + ctx context.Context, txn *sql.Tx, inviteEvent *rstypes.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn) if err != nil { @@ -108,7 +108,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( streamPos, inviteEvent.RoomID(), inviteEvent.EventID(), - *inviteEvent.StateKey(), + inviteEvent.UserID.String(), headeredJSON, ) return @@ -130,7 +130,7 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { +) (map[string]*rstypes.HeaderedEvent, map[string]*rstypes.HeaderedEvent, types.StreamPosition, error) { var lastPos types.StreamPosition stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) @@ -138,8 +138,8 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( return nil, nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") - result := map[string]*gomatrixserverlib.HeaderedEvent{} - retired := map[string]*gomatrixserverlib.HeaderedEvent{} + result := map[string]*rstypes.HeaderedEvent{} + retired := map[string]*rstypes.HeaderedEvent{} for rows.Next() { var ( id types.StreamPosition @@ -162,7 +162,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( continue } - var event *gomatrixserverlib.HeaderedEvent + var event *rstypes.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { return nil, nil, lastPos, err } diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 2cc46a10a2..a9e880d2a4 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -19,9 +19,8 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -103,7 +102,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { } func (s *membershipsStatements) UpsertMembership( - ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, + ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, streamPos, topologicalPos types.StreamPosition, ) error { membership, err := event.Membership() @@ -113,7 +112,7 @@ func (s *membershipsStatements) UpsertMembership( _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, event.RoomID(), - *event.StateKey(), + event.StateKeyResolved, membership, event.EventID(), streamPos, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 33ca687df1..06c65419af 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -25,11 +25,11 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -82,12 +82,6 @@ const selectRecentEventsForSyncSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters -const selectEarlyEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" - -// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" @@ -119,7 +113,7 @@ const selectContextAfterEventSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters -const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type IN ($2)" const purgeEventsSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -173,7 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err @@ -256,8 +250,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange( } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, nil, err } needSet := stateNeeded[ev.RoomID()] @@ -303,7 +297,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( // of the inserted event. func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, - event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, + event *rstypes.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility, ) (types.StreamPosition, error) { var txnID *string @@ -354,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.UserID.String(), containsURL, string(addStateJSON), string(removeStateJSON), @@ -430,42 +424,6 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( return result, nil } -func (s *outputRoomEventsStatements) SelectEarlyEvents( - ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *synctypes.RoomEventFilter, -) ([]types.StreamEvent, error) { - stmt, params, err := prepareWithFilters( - s.db, txn, selectEarlyEventsSQL, - []interface{}{ - roomID, r.Low(), r.High(), - }, - eventFilter.Senders, eventFilter.NotSenders, - eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc, - ) - if err != nil { - return nil, fmt.Errorf("s.prepareWithFilters: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectEarlyEvents: stmt.close() failed") - - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, err - } - // The events need to be returned from oldest to latest, which isn't - // necessarily the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - return events, nil -} - // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( @@ -541,8 +499,8 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return nil, err } // TODO: Handle redacted events - var ev gomatrixserverlib.HeaderedEvent - if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + var ev rstypes.HeaderedEvent + if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } @@ -566,7 +524,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { } func (s *outputRoomEventsStatements) SelectContextEvent( ctx context.Context, txn *sql.Tx, roomID, eventID string, -) (id int, evt gomatrixserverlib.HeaderedEvent, err error) { +) (id int, evt rstypes.HeaderedEvent, err error) { row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID) var eventAsString string var historyVisibility gomatrixserverlib.HistoryVisibility @@ -583,7 +541,7 @@ func (s *outputRoomEventsStatements) SelectContextEvent( func (s *outputRoomEventsStatements) SelectContextBeforeEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, -) (evts []*gomatrixserverlib.HeaderedEvent, err error) { +) (evts []*rstypes.HeaderedEvent, err error) { stmt, params, err := prepareWithFilters( s.db, txn, selectContextBeforeEventSQL, []interface{}{ @@ -607,7 +565,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( for rows.Next() { var ( eventBytes []byte - evt *gomatrixserverlib.HeaderedEvent + evt *rstypes.HeaderedEvent historyVisibility gomatrixserverlib.HistoryVisibility ) if err = rows.Scan(&eventBytes, &historyVisibility); err != nil { @@ -625,7 +583,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter, -) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { +) (lastID int, evts []*rstypes.HeaderedEvent, err error) { stmt, params, err := prepareWithFilters( s.db, txn, selectContextAfterEventSQL, []interface{}{ @@ -649,7 +607,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( for rows.Next() { var ( eventBytes []byte - evt *gomatrixserverlib.HeaderedEvent + evt *rstypes.HeaderedEvent historyVisibility gomatrixserverlib.HistoryVisibility ) if err = rows.Scan(&lastID, &eventBytes, &historyVisibility); err != nil { @@ -685,19 +643,19 @@ func (s *outputRoomEventsStatements) PurgeEvents( return err } -func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { - params := make([]interface{}, len(types)) +func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]rstypes.HeaderedEvent, error) { + params := make([]interface{}, len(types)+1) + params[0] = afterID for i := range types { - params[i] = types[i] + params[i+1] = types[i] } - params = append(params, afterID) - params = append(params, limit) - selectSQL := strings.Replace(selectSearchSQL, "($1)", sqlutil.QueryVariadic(len(types)), 1) - stmt, err := s.db.Prepare(selectSQL) + selectSQL := strings.Replace(selectSearchSQL, "($2)", sqlutil.QueryVariadicOffset(len(types), 1), 1) + stmt, params, err := prepareWithFilters(s.db, txn, selectSQL, params, nil, nil, nil, nil, nil, nil, int(limit), FilterOrderAsc) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed") rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) if err != nil { @@ -707,14 +665,14 @@ func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, l var eventID string var id int64 - result := make(map[int64]gomatrixserverlib.HeaderedEvent) + result := make(map[int64]rstypes.HeaderedEvent) for rows.Next() { - var ev gomatrixserverlib.HeaderedEvent + var ev rstypes.HeaderedEvent var eventBytes []byte if err = rows.Scan(&id, &eventID, &eventBytes); err != nil { return nil, err } - if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + if err = json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } result[id] = ev diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index dc698de2d3..68b75f5b18 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,9 +18,8 @@ import ( "context" "database/sql" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -104,7 +103,7 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { // insertEventInTopology inserts the given event in the room's topology, based // on the event's depth. func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( - ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, + ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition, ) (types.StreamPosition, error) { _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext( ctx, event.EventID(), event.Depth(), event.RoomID(), pos, diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index bddfe3d79b..3a1f1b4e8f 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -20,7 +20,7 @@ import ( "strings" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -113,7 +113,7 @@ func (p *presenceStatements) UpsertPresence( userID string, statusMsg *string, presence types.Presence, - lastActiveTS gomatrixserverlib.Timestamp, + lastActiveTS spec.Timestamp, fromSync bool, ) (pos types.StreamPosition, err error) { pos, err = p.streamIDStatements.nextPresenceID(ctx, txn) @@ -185,7 +185,7 @@ func (p *presenceStatements) GetPresenceAfter( ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + afterTS := spec.AsTimestamp(time.Now().Add(time.Minute * -5)) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index ca3d80fb45..b973903bd6 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const receiptsSchema = ` @@ -97,7 +97,7 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re } // UpsertReceipt creates new user receipts -func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextReceiptID(ctx, txn) if err != nil { return diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 2da722b2b1..992466450f 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -10,13 +10,17 @@ import ( "testing" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" ) var ctx = context.Background() @@ -33,12 +37,13 @@ func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun return db, close } -func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { +func MustWriteEvents(t *testing.T, db storage.Database, events []*rstypes.HeaderedEvent) (positions []types.StreamPosition) { for _, ev := range events { - var addStateEvents []*gomatrixserverlib.HeaderedEvent + var addStateEvents []*rstypes.HeaderedEvent var addStateEventIDs []string var removeStateEventIDs []string if ev.StateKey() != nil { + ev.StateKeyResolved = ev.StateKey() addStateEvents = append(addStateEvents, ev) addStateEventIDs = append(addStateEventIDs, ev.EventID()) } @@ -105,7 +110,7 @@ func TestRecentEventsPDU(t *testing.T) { To types.StreamPosition Limit int ReverseOrder bool - WantEvents []*gomatrixserverlib.HeaderedEvent + WantEvents []*rstypes.HeaderedEvent WantLimited bool }{ // The purpose of this test is to make sure that incremental syncs are including up to the latest events. @@ -212,7 +217,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) } - gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) + gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) }) }) @@ -315,7 +320,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { t.Parallel() db := MustCreateDatabase(t) - var events []*gomatrixserverlib.HeaderedEvent + var events []*types.HeaderedEvent events = append(events, MustCreateEvent(t, testRoomID, nil, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), Type: "m.room.create", @@ -323,7 +328,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { Sender: testUserIDA, Depth: int64(len(events) + 1), })) - events = append(events, MustCreateEvent(t, testRoomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, testRoomID, []*types.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &testUserIDA, @@ -331,7 +336,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { Depth: int64(len(events) + 1), })) // fork the dag into three, same prev_events and depth - parent := []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]} + parent := []*types.HeaderedEvent{events[len(events)-1]} depth := int64(len(events) + 1) for i := 0; i < 3; i++ { events = append(events, MustCreateEvent(t, testRoomID, parent, &gomatrixserverlib.EventBuilder{ @@ -364,7 +369,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { Name string From types.TopologyToken Limit int - Wants []*gomatrixserverlib.HeaderedEvent + Wants []*types.HeaderedEvent }{ { Name: "Pagination over the whole fork", @@ -405,7 +410,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) { t.Parallel() db := MustCreateDatabase(t) - makeEvents := func(roomID string) (events []*gomatrixserverlib.HeaderedEvent) { + makeEvents := func(roomID string) (events []*types.HeaderedEvent) { events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), Type: "m.room.create", @@ -413,7 +418,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) { Sender: testUserIDA, Depth: int64(len(events) + 1), })) - events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, roomID, []*types.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &testUserIDA, @@ -459,14 +464,14 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { // "federation" join userC := fmt.Sprintf("@radiance:%s", testOrigin) - joinEvent := MustCreateEvent(t, testRoomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + joinEvent := MustCreateEvent(t, testRoomID, []*types.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &userC, Sender: userC, Depth: int64(len(events) + 1), }) - MustWriteEvents(t, db, []*gomatrixserverlib.HeaderedEvent{joinEvent}) + MustWriteEvents(t, db, []*types.HeaderedEvent{joinEvent}) // Sync will return this for the prev_batch from := topologyTokenBefore(t, db, joinEvent.EventID()) @@ -637,7 +642,7 @@ func TestInviteBehaviour(t *testing.T) { StateKey: &testUserIDA, Sender: "@inviteUser2:somewhere", }) - for _, ev := range []*gomatrixserverlib.HeaderedEvent{inviteEvent1, inviteEvent2} { + for _, ev := range []*types.HeaderedEvent{inviteEvent1, inviteEvent2} { _, err := db.AddInviteEvent(ctx, ev) if err != nil { t.Fatalf("Failed to AddInviteEvent: %s", err) @@ -694,7 +699,7 @@ func assertInvitedToRooms(t *testing.T, res *types.Response, roomIDs []string) { } } -func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []*gomatrixserverlib.HeaderedEvent) { +func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []*types.HeaderedEvent) { t.Helper() if len(gots) != len(wants) { t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) @@ -777,7 +782,7 @@ func TestRoomSummary(t *testing.T) { name: "invited user", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) }, @@ -786,10 +791,10 @@ func TestRoomSummary(t *testing.T) { name: "invited user, but declined", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -798,10 +803,10 @@ func TestRoomSummary(t *testing.T) { name: "joined user after invitation", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -810,10 +815,10 @@ func TestRoomSummary(t *testing.T) { name: "multiple joined user", wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(charlie.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -822,10 +827,10 @@ func TestRoomSummary(t *testing.T) { name: "multiple joined/invited user", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(charlie.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -834,13 +839,13 @@ func TestRoomSummary(t *testing.T) { name: "multiple joined/invited/left user", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "invite", }, test.WithStateKey(charlie.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -849,10 +854,10 @@ func TestRoomSummary(t *testing.T) { name: "leaving user after joining", wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -862,7 +867,7 @@ func TestRoomSummary(t *testing.T) { wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]}, additionalEvents: func(t *testing.T, room *test.Room) { for _, x := range moreUsers { - room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, x, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(x.ID)) } @@ -872,10 +877,10 @@ func TestRoomSummary(t *testing.T) { name: "canonical alias set", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomCanonicalAlias, map[string]interface{}{ "alias": "myalias", }, test.WithStateKey("")) }, @@ -884,10 +889,10 @@ func TestRoomSummary(t *testing.T) { name: "room name set", wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{ "name": "my room name", }, test.WithStateKey("")) }, @@ -976,3 +981,52 @@ func TestRecentEvents(t *testing.T) { } }) } + +type FakeQuerier struct { + api.QuerySenderIDAPI +} + +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func TestRedaction(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) + redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + t.Cleanup(close) + MustWriteEvents(t, db, room.Events()) + + err := db.RedactEvent(context.Background(), redactedEvent.EventID(), redactionEvent, &FakeQuerier{}) + if err != nil { + t.Fatal(err) + } + + evs, err := db.Events(context.Background(), []string{redactedEvent.EventID()}) + if err != nil { + t.Fatal(err) + } + + if len(evs) != 1 { + t.Fatalf("expected 1 event, got %d", len(evs)) + } + + // check a few fields which shouldn't be there in unsigned + authEvs := gjson.GetBytes(evs[0].Unsigned(), "redacted_because.auth_events") + if authEvs.Exists() { + t.Error("unexpected auth_events in redacted event") + } + prevEvs := gjson.GetBytes(evs[0].Unsigned(), "redacted_because.prev_events") + if prevEvs.Exists() { + t.Error("unexpected auth_events in redacted event") + } + depth := gjson.GetBytes(evs[0].Unsigned(), "redacted_because.depth") + if depth.Exists() { + t.Error("unexpected auth_events in redacted event") + } + }) +} diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go index 5fe06c3cec..2df111a267 100644 --- a/syncapi/storage/tables/current_room_state_test.go +++ b/syncapi/storage/tables/current_room_state_test.go @@ -14,7 +14,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { @@ -54,7 +54,13 @@ func TestCurrentRoomStateTable(t *testing.T) { events := room.CurrentState() err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { for i, ev := range events { - err := tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i)) + ev.StateKeyResolved = ev.StateKey() + userID, err := spec.NewUserID(string(ev.SenderID()), true) + if err != nil { + return err + } + ev.UserID = *userID + err = tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i)) if err != nil { return fmt.Errorf("failed to UpsertRoomState: %w", err) } @@ -115,7 +121,7 @@ func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables t.Fatalf("expected %d state events, got %d", expectCount, gotCount) } // same as above, but with existing NotTypes defined - notTypes := []string{gomatrixserverlib.MRoomMember} + notTypes := []string{spec.MRoomMember} filter.NotTypes = ¬Types evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) if err != nil { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 6065d39c74..5085354e45 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -19,9 +19,11 @@ import ( "database/sql" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -34,11 +36,11 @@ type AccountData interface { } type Invites interface { - InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent *rstypes.HeaderedEvent) (streamPos types.StreamPosition, err error) DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error) // SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value // for the room. - SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) + SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*rstypes.HeaderedEvent, retired map[string]*rstypes.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error } @@ -58,7 +60,7 @@ type Events interface { SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) InsertEvent( ctx context.Context, txn *sql.Tx, - event *gomatrixserverlib.HeaderedEvent, + event *rstypes.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, @@ -68,19 +70,17 @@ type Events interface { // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *synctypes.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) - // SelectEarlyEvents returns the earliest events in the given room. - SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *synctypes.RoomEventFilter) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *synctypes.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) - UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error + UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) - SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) - SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) - SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, rstypes.HeaderedEvent, error) + SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter) ([]*rstypes.HeaderedEvent, error) + SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *synctypes.RoomEventFilter) (int, []*rstypes.HeaderedEvent, error) PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error - ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) + ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]rstypes.HeaderedEvent, error) } // Topology keeps track of the depths and stream positions for all events. @@ -88,7 +88,7 @@ type Events interface { type Topology interface { // InsertEventInTopology inserts the given event in the room's topology, based on the event's depth. // `pos` is the stream position of this event in the events table, and is used to order events which have the same depth. - InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) (topoPos types.StreamPosition, err error) + InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition) (topoPos types.StreamPosition, err error) // SelectEventIDsInRange selects the IDs of events whose depths are within a given range in a given room's topological order. // Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`. // `maxStreamPos` is only used when events have the same depth as `maxDepth`, which results in events less than `maxStreamPos` being returned. @@ -102,13 +102,13 @@ type Topology interface { } type CurrentRoomState interface { - SelectStateEvent(ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + SelectStateEvent(ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string) (*rstypes.HeaderedEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) - UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error + UpsertRoomState(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, membership *string, addedAt types.StreamPosition) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error // SelectCurrentState returns all the current state events for the given room. - SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *synctypes.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. @@ -185,14 +185,14 @@ type Filter interface { } type Receipts interface { - UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp spec.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error } type Memberships interface { - UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error + UpsertMembership(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error @@ -216,7 +216,7 @@ type Ignores interface { } type Presence interface { - UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) + UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS spec.Timestamp, fromSync bool) (pos types.StreamPosition, err error) GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter synctypes.EventFilter) (presences map[string]*types.PresenceInternal, err error) diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index df593ae781..a421a97727 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -6,9 +6,10 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/sqlutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" @@ -46,7 +47,7 @@ func TestMembershipsTable(t *testing.T) { room := test.NewRoom(t, alice) // Create users - var userEvents []*gomatrixserverlib.HeaderedEvent + var userEvents []*rstypes.HeaderedEvent users := []string{alice.ID} for _, x := range room.CurrentState() { if x.StateKeyEquals(alice.ID) { @@ -65,7 +66,7 @@ func TestMembershipsTable(t *testing.T) { u := test.NewUser(t) users = append(users, u.ID) - ev := room.CreateAndInsert(t, u, gomatrixserverlib.MRoomMember, map[string]interface{}{ + ev := room.CreateAndInsert(t, u, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(u.ID)) userEvents = append(userEvents, ev) @@ -79,6 +80,7 @@ func TestMembershipsTable(t *testing.T) { defer cancel() for _, ev := range userEvents { + ev.StateKeyResolved = ev.StateKey() if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil { t.Fatalf("failed to upsert membership: %s", err) } @@ -92,7 +94,7 @@ func TestMembershipsTable(t *testing.T) { func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { t.Run("membership counts are correct", func(t *testing.T) { // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) - count, err := table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 10) + count, err := table.SelectMembershipCount(ctx, nil, room.ID, spec.Join, 10) if err != nil { t.Fatalf("failed to get membership count: %s", err) } @@ -102,7 +104,7 @@ func testMembershipCount(t *testing.T, ctx context.Context, table tables.Members } // After 100 events, we should have all 11 users - count, err = table.SelectMembershipCount(ctx, nil, room.ID, gomatrixserverlib.Join, 100) + count, err = table.SelectMembershipCount(ctx, nil, room.ID, spec.Join, 100) if err != nil { t.Fatalf("failed to get membership count: %s", err) } @@ -113,7 +115,7 @@ func testMembershipCount(t *testing.T, ctx context.Context, table tables.Members }) } -func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *gomatrixserverlib.HeaderedEvent, user *test.User, room *test.Room) { +func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, membershipEvent *rstypes.HeaderedEvent, user *test.User, room *test.Room) { t.Run("upserting works as expected", func(t *testing.T) { if err := table.UpsertMembership(ctx, nil, membershipEvent, 1, 1); err != nil { t.Fatalf("failed to upsert membership: %s", err) @@ -126,13 +128,14 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem if pos != expectedPos { t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { t.Fatalf("expected membership to be join, got %s", membership) } // Create a new event which gets upserted and should not cause issues - ev := room.CreateAndInsert(t, user, gomatrixserverlib.MRoomMember, map[string]interface{}{ - "membership": gomatrixserverlib.Join, + ev := room.CreateAndInsert(t, user, spec.MRoomMember, map[string]interface{}{ + "membership": spec.Join, }, test.WithStateKey(user.ID)) + ev.StateKeyResolved = ev.StateKey() // Insert the same event again, but with different positions, which should get updated if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil { t.Fatalf("failed to upsert membership: %s", err) @@ -147,7 +150,7 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem if pos != expectedPos { t.Fatalf("expected pos to be %d, got %d", expectedPos, pos) } - if membership != gomatrixserverlib.Join { + if membership != spec.Join { t.Fatalf("expected membership to be join, got %s", membership) } @@ -155,7 +158,7 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem if membership, _, err = table.SelectMembershipForUser(ctx, nil, room.ID, user.ID, 1); err != nil { t.Fatalf("failed to select membership: %s", err) } - if membership != gomatrixserverlib.Leave { + if membership != spec.Leave { t.Fatalf("expected membership to be leave, got %s", membership) } }) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index c0d4511115..9b755dc85d 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -15,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { @@ -104,3 +105,53 @@ func TestOutputRoomEventsTable(t *testing.T) { } }) } + +func TestReindex(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + room.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{ + "name": "my new room name", + }, test.WithStateKey("")) + + room.CreateAndInsert(t, alice, spec.MRoomTopic, map[string]interface{}{ + "topic": "my new room topic", + }, test.WithStateKey("")) + + room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{ + "msgbody": "my room message", + "type": "m.text", + }) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newOutputRoomEventsTable(t, dbType) + defer close() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + for _, ev := range room.Events() { + _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared) + if err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + } + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + + events, err := tab.ReIndex(ctx, nil, 10, 0, []string{ + spec.MRoomName, + spec.MRoomTopic, + "m.room.message"}) + if err != nil { + t.Fatal(err) + } + + wantEventCount := 3 + if len(events) != wantEventCount { + t.Fatalf("expected %d events, got %d", wantEventCount, len(events)) + } + }) +} diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go index cb7a4dee99..d8161836b1 100644 --- a/syncapi/storage/tables/presence_table_test.go +++ b/syncapi/storage/tables/presence_table_test.go @@ -7,8 +7,6 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres" @@ -17,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib/spec" ) func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { @@ -52,7 +51,7 @@ func TestPresence(t *testing.T) { ctx := context.Background() statusMsg := "Hello World!" - timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + timestamp := spec.AsTimestamp(time.Now()) var txn *sql.Tx test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 22953b8c1c..51f2a3d30f 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -3,12 +3,11 @@ package streams import ( "context" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type AccountDataStreamProvider struct { @@ -85,7 +84,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( req.Response.AccountData.Events, synctypes.ClientEvent{ Type: dataType, - Content: gomatrixserverlib.RawJSON(globalData), + Content: spec.RawJSON(globalData), }, ) } @@ -99,7 +98,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( joinData.AccountData.Events, synctypes.ClientEvent{ Type: dataType, - Content: gomatrixserverlib.RawJSON(roomData), + Content: spec.RawJSON(roomData), }, ) req.Response.Rooms.Join[roomID] = joinData diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index a4414f3154..7c29d84ae0 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -8,8 +8,9 @@ import ( "strconv" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -17,6 +18,7 @@ import ( type InviteStreamProvider struct { DefaultStreamProvider + rsAPI api.SyncRoomserverAPI } func (p *InviteStreamProvider) Setup( @@ -62,11 +64,30 @@ func (p *InviteStreamProvider) IncrementalSync( } for roomID, inviteEvent := range invites { + user := spec.UserID{} + validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) + if err != nil { + continue + } + sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID()) + if err == nil && sender != nil { + user = *sender + } + + sk := inviteEvent.StateKey() + if sk != nil && *sk != "" { + skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + // skip ignored user events - if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok { + if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent) + ir := types.NewInviteResponse(inviteEvent, user, sk) req.Response.Rooms.Invite[roomID] = ir } @@ -79,7 +100,7 @@ func (p *InviteStreamProvider) IncrementalSync( membership, _, err := snapshot.SelectMembershipForUser(ctx, roomID, req.Device.UserID, math.MaxInt64) // Skip if the user is an existing member of the room. // Otherwise, the NewLeaveResponse will eject the user from the room unintentionally - if membership == gomatrixserverlib.Join || + if membership == spec.Join || err != nil { continue } @@ -89,12 +110,12 @@ func (p *InviteStreamProvider) IncrementalSync( lr.Timeline.Events = append(lr.Timeline.Events, synctypes.ClientEvent{ // fake event ID which muxes in the to position EventID: "$" + base64.RawURLEncoding.EncodeToString(h[:]), - OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), + OriginServerTS: spec.AsTimestamp(time.Now()), RoomID: roomID, Sender: req.Device.UserID, StateKey: &req.Device.UserID, Type: "m.room.member", - Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), + Content: spec.RawJSON(`{"membership":"leave"}`), }) req.Response.Rooms.Leave[roomID] = lr } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index cf014714a9..5bc0dac625 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,16 +3,21 @@ package streams import ( "context" "database/sql" + "encoding/json" "fmt" "time" "github.com/matrix-org/dendrite/internal/caching" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" @@ -70,7 +75,7 @@ func (p *PDUStreamProvider) CompleteSync( } // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) + joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, spec.Join) if err != nil { req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed") return from @@ -111,7 +116,7 @@ func (p *PDUStreamProvider) CompleteSync( continue } req.Response.Rooms.Join[roomID] = jr - req.Rooms[roomID] = gomatrixserverlib.Join + req.Rooms[roomID] = spec.Join } // Add peeked rooms. @@ -174,19 +179,19 @@ func (p *PDUStreamProvider) IncrementalSync( eventFilter := req.Filter.Room.Timeline if req.WantFullState { - if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") return from } } else { - if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") return from } } for _, roomID := range syncJoinedRooms { - req.Rooms[roomID] = gomatrixserverlib.Join + req.Rooms[roomID] = spec.Join } req.JoinedRooms = syncJoinedRooms @@ -274,10 +279,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( recentStreamEvents := dbEvents[delta.RoomID].Events limited := dbEvents[delta.RoomID].Limited - recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( - snapshot.StreamEventsToEvents(device, recentStreamEvents), + recEvents := gomatrixserverlib.ReverseTopologicalOrdering( + gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)), gomatrixserverlib.TopologicalOrderByPrevEvents, ) + recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) + for i := range recEvents { + recentEvents[i] = recEvents[i].(*rstypes.HeaderedEvent) + } // If we didn't return any events at all then don't bother doing anything else. if len(recentEvents) == 0 && len(delta.StateEvents) == 0 { @@ -314,8 +323,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( hasMembershipChange := false for _, recentEvent := range recentStreamEvents { - if recentEvent.Type() == gomatrixserverlib.MRoomMember && recentEvent.StateKey() != nil { - if membership, _ := recentEvent.Membership(); membership == gomatrixserverlib.Join { + if recentEvent.Type() == spec.MRoomMember && recentEvent.StateKey() != nil { + if membership, _ := recentEvent.Membership(); membership == spec.Join { req.MembershipChanges[*recentEvent.StateKey()] = struct{}{} } hasMembershipChange = true @@ -342,10 +351,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // Now that we've filtered the timeline, work out which state events are still // left. Anything that appears in the filtered timeline will be removed from the // "state" section and kept in "timeline". - delta.StateEvents = gomatrixserverlib.HeaderedReverseTopologicalOrdering( - removeDuplicates(delta.StateEvents, events), + + // update the powerlevel event for timeline events + for i, ev := range events { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + var newEvent gomatrixserverlib.PDU + newEvent, err = p.updatePowerLevelEvent(ctx, ev) + if err != nil { + return r.From, err + } + events[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + + sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( + gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)), gomatrixserverlib.TopologicalOrderByAuthEvents, ) + delta.StateEvents = make([]*rstypes.HeaderedEvent, len(sEvents)) + for i := range sEvents { + ev := sEvents[i] + delta.StateEvents[i] = ev.(*rstypes.HeaderedEvent) + // update the powerlevel event for state events + if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") { + var newEvent gomatrixserverlib.PDU + newEvent, err = p.updatePowerLevelEvent(ctx, ev.(*rstypes.HeaderedEvent)) + if err != nil { + return r.From, err + } + delta.StateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + } if len(delta.StateEvents) > 0 { if last := delta.StateEvents[len(delta.StateEvents)-1]; last != nil { @@ -359,7 +399,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } switch delta.Membership { - case gomatrixserverlib.Join: + case spec.Join: jr := types.NewJoinResponse() if hasMembershipChange { jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID) @@ -368,39 +408,120 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.HeaderedToClientEvents(events, synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.HeaderedToClientEvents(delta.StateEvents, synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Join[delta.RoomID] = jr - case gomatrixserverlib.Peek: + case spec.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.HeaderedToClientEvents(recentEvents, synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.HeaderedToClientEvents(delta.StateEvents, synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Peek[delta.RoomID] = jr - case gomatrixserverlib.Leave: + case spec.Leave: fallthrough // transitions to leave are the same as ban - case gomatrixserverlib.Ban: + case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.HeaderedToClientEvents(events, synctypes.FormatSync) + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.HeaderedToClientEvents(delta.StateEvents, synctypes.FormatSync) + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Leave[delta.RoomID] = lr } return latestPosition, nil } +func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent) (gomatrixserverlib.PDU, error) { + pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev) + if err != nil { + return nil, err + } + newPls := make(map[string]int64) + var userID *spec.UserID + for user, level := range pls.Users { + validRoomID, _ := spec.NewRoomID(ev.RoomID()) + userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) + if err != nil { + return nil, err + } + newPls[userID.String()] = level + } + var newPlBytes, newEv []byte + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes) + if err != nil { + return nil, err + } + + // do the same for prev content + prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content") + if !prevContent.Exists() { + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err + } + pls = gomatrixserverlib.PowerLevelContent{} + err = json.Unmarshal([]byte(prevContent.Raw), &pls) + if err != nil { + return nil, err + } + + newPls = make(map[string]int64) + for user, level := range pls.Users { + validRoomID, _ := spec.NewRoomID(ev.RoomID()) + userID, err = p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(user)) + if err != nil { + return nil, err + } + newPls[userID.String()] = level + } + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) + if err != nil { + return nil, err + } + + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionPseudoIDs).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err +} + // applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( @@ -408,8 +529,8 @@ func applyHistoryVisibilityFilter( snapshot storage.DatabaseTransaction, rsAPI roomserverAPI.SyncRoomserverAPI, roomID, userID string, - recentEvents []*gomatrixserverlib.HeaderedEvent, -) ([]*gomatrixserverlib.HeaderedEvent, error) { + recentEvents []*rstypes.HeaderedEvent, +) ([]*rstypes.HeaderedEvent, error) { // We need to make sure we always include the latest state events, if they are in the timeline. alwaysIncludeIDs := make(map[string]struct{}) var stateTypes []string @@ -417,7 +538,7 @@ func applyHistoryVisibilityFilter( for _, ev := range recentEvents { if ev.StateKey() != nil { stateTypes = append(stateTypes, ev.Type()) - senders = append(senders, ev.Sender()) + senders = append(senders, string(ev.SenderID())) } } @@ -450,6 +571,7 @@ func applyHistoryVisibilityFilter( return events, nil } +// nolint: gocyclo func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, @@ -492,7 +614,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // "Can sync a room with a message with a transaction id" - which does a complete sync to check. - recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI) events := recentEvents // Only apply history visibility checks if the response is for joined rooms @@ -529,7 +651,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( event := events[0] // If this is the beginning of the room, we can't go back further. We're going to return // the TopologyToken from the last event instead. (Synapse returns the /sync next_Batch) - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + if event.Type() == spec.MRoomCreate && event.StateKeyEquals("") { event = events[len(events)-1] } backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, event.EventID()) @@ -543,12 +665,45 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } + // Update powerlevel events for timeline events + for i, ev := range events { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + newEvent, err := p.updatePowerLevelEvent(ctx, ev) + if err != nil { + return nil, err + } + events[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + // Update powerlevel events for state events + for i, ev := range stateEvents { + if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { + continue + } + if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { + continue + } + newEvent, err := p.updatePowerLevelEvent(ctx, ev) + if err != nil { + return nil, err + } + stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} + } + jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.HeaderedToClientEvents(events, synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.HeaderedToClientEvents(stateEvents, synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) return jr, nil } @@ -556,11 +711,15 @@ func (p *PDUStreamProvider) lazyLoadMembers( ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, incremental, limited bool, stateFilter *synctypes.StateFilter, device *userapi.Device, - timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, -) ([]*gomatrixserverlib.HeaderedEvent, error) { + timelineEvents, stateEvents []*rstypes.HeaderedEvent, +) ([]*rstypes.HeaderedEvent, error) { if len(timelineEvents) == 0 { return stateEvents, nil } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } // Work out which memberships to include timelineUsers := make(map[string]struct{}) if !incremental { @@ -569,25 +728,29 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { - timelineUsers[event.Sender()] = struct{}{} + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, string(event.SenderID())); !ok { + timelineUsers[string(event.SenderID())] = struct{}{} } } // Preallocate with the same amount, even if it will end up with fewer values - newStateEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(stateEvents)) + newStateEvents := make([]*rstypes.HeaderedEvent, 0, len(stateEvents)) // Remove existing membership events we don't care about, e.g. users not in the timeline.events for _, event := range stateEvents { - if event.Type() == gomatrixserverlib.MRoomMember && event.StateKey() != nil { + if event.Type() == spec.MRoomMember && event.StateKey() != nil { // If this is a gapped incremental sync, we still want this membership isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list - stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { + userID := "" + stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) + if queryErr == nil && stateKeyUserID != nil { + userID = stateKeyUserID.String() + } + if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { - p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) + p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID()) } - delete(timelineUsers, stateKey) + delete(timelineUsers, userID) } } else { newStateEvents = append(newStateEvents, event) @@ -600,7 +763,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Query missing membership events filter := synctypes.DefaultStateFilter() filter.Senders = &wantUsers - filter.Types = &[]string{gomatrixserverlib.MRoomMember} + filter.Types = &[]string{spec.MRoomMember} memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) if err != nil { return stateEvents, err @@ -634,7 +797,7 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapsho return nil } -func removeDuplicates(stateEvents, recentEvents []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { +func removeDuplicates[T gomatrixserverlib.PDU](stateEvents, recentEvents []T) []T { for _, recentEv := range recentEvents { if recentEv.StateKey() == nil { continue // not a state event diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 3ef58bc1cf..69364cd08c 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -19,13 +19,13 @@ import ( "encoding/json" "fmt" - "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) type PresenceStreamProvider struct { @@ -118,7 +118,7 @@ func (p *PresenceStreamProvider) IncrementalSync( req.Response.Presence.Events = append(req.Response.Presence.Events, synctypes.ClientEvent{ Content: content, Sender: presence.UserID, - Type: gomatrixserverlib.MPresence, + Type: spec.MPresence, }) if presence.StreamPos > lastPos { lastPos = presence.StreamPos @@ -190,7 +190,7 @@ func membershipEventPresent(events []synctypes.ClientEvent, userID string) bool for _, ev := range events { // it's enough to know that we have our member event here, don't need to check membership content // as it's implied by being in the respective section of the sync response. - if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { + if ev.Type == spec.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID { // ignore e.g. join -> join changes if gjson.GetBytes(ev.Unsigned, "prev_content.membership").Str == gjson.GetBytes(ev.Content, "membership").Str { continue diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 88db0054e3..ed52dc5c70 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -46,7 +46,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( ) types.StreamPosition { var joinedRooms []string for roomID, membership := range req.Rooms { - if membership == gomatrixserverlib.Join { + if membership == spec.Join { joinedRooms = append(joinedRooms, roomID) } } @@ -88,7 +88,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } ev := synctypes.ClientEvent{ - Type: gomatrixserverlib.MReceipt, + Type: spec.MReceipt, } content := make(map[string]ReceiptMRead) for _, receipt := range receipts { @@ -119,5 +119,5 @@ type ReceiptMRead struct { } type ReceiptTS struct { - TS gomatrixserverlib.Timestamp `json:"ts"` + TS spec.Timestamp `json:"ts"` } diff --git a/syncapi/streams/stream_typing.go b/syncapi/streams/stream_typing.go index b0e7d9e7c2..15500a470f 100644 --- a/syncapi/streams/stream_typing.go +++ b/syncapi/streams/stream_typing.go @@ -4,12 +4,11 @@ import ( "context" "encoding/json" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) type TypingStreamProvider struct { @@ -33,7 +32,7 @@ func (p *TypingStreamProvider) IncrementalSync( ) types.StreamPosition { var err error for roomID, membership := range req.Rooms { - if membership != gomatrixserverlib.Join { + if membership != spec.Join { continue } @@ -53,7 +52,7 @@ func (p *TypingStreamProvider) IncrementalSync( } } ev := synctypes.ClientEvent{ - Type: gomatrixserverlib.MTyping, + Type: spec.MTyping, } ev.Content, err = json.Marshal(map[string]interface{}{ "user_ids": typingUsers, diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 5c112ed459..ef75e17976 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -48,6 +48,7 @@ func NewSyncStreamProviders( }, InviteStreamProvider: &InviteStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, + rsAPI: rsAPI, }, SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 22ee340bb8..57a7fc954e 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -25,12 +25,11 @@ import ( "sync" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -61,7 +60,7 @@ type PresencePublisher interface { } type PresenceConsumer interface { - EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts gomatrixserverlib.Timestamp, fromSync bool) + EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts spec.Timestamp, fromSync bool) } // NewRequestPool makes a new RequestPool @@ -138,7 +137,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user newPresence := types.PresenceInternal{ Presence: presenceID, UserID: userID, - LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()), + LastActiveTS: spec.AsTimestamp(time.Now()), } // ensure we also send the current status_msg to federated servers and not nil @@ -170,7 +169,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user // the /sync response else we may not return presence: online immediately. rp.consumer.EmitPresence( context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg, - gomatrixserverlib.AsTimestamp(time.Now()), true, + spec.AsTimestamp(time.Now()), true, ) } @@ -236,12 +235,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. if err == types.ErrMalformedSyncToken { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), + JSON: spec.InvalidParam(err.Error()), } } return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Unknown(err.Error()), + JSON: spec.Unknown(err.Error()), } } @@ -538,32 +537,38 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use if from == "" || to == "" { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("missing ?from= or ?to="), + JSON: spec.InvalidParam("missing ?from= or ?to="), } } fromToken, err := types.NewStreamTokenFromString(from) if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("bad 'from' value"), + JSON: spec.InvalidParam("bad 'from' value"), } } toToken, err := types.NewStreamTokenFromString(to) if err != nil { return util.JSONResponse{ Code: 400, - JSON: jsonerror.InvalidArgumentValue("bad 'to' value"), + JSON: spec.InvalidParam("bad 'to' value"), } } syncReq, err := newSyncRequest(req, *device, rp.db) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) if err != nil { logrus.WithError(err).Error("Failed to acquire database snapshot for key change") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } var succeeded bool defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) @@ -574,7 +579,10 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("Failed to DeviceListCatchup info") - return jsonerror.InternalServerError() + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } succeeded = true return util.JSONResponse{ diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index cd0973dbc1..1b242ad006 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -10,7 +10,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type dummyPublisher struct { @@ -29,7 +29,7 @@ type dummyDB struct { storage.Database } -func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { +func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS spec.Timestamp, fromSync bool) (types.StreamPosition, error) { return 0, nil } @@ -47,7 +47,7 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream type dummyConsumer struct{} -func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts gomatrixserverlib.Timestamp, fromSync bool) { +func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts spec.Timestamp, fromSync bool) { } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 470cfcd5f6..a2f0bc6a46 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -74,7 +74,7 @@ func AddPublicRoutes( }() eduCache := caching.NewTypingCache() - notifier := notifier.NewNotifier() + notifier := notifier.NewNotifier(rsAPI) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier, mrq) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index a39de64860..1080cdbb85 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -15,9 +15,11 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/tidwall/gjson" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -38,6 +40,10 @@ type syncRoomserverAPI struct { rooms []*test.Room } +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { var room *test.Room for _, r := range s.rooms { @@ -489,7 +495,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { afterJoinBody := fmt.Sprintf("After join in a %s room", tc.historyVisibility) msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": afterJoinBody}) - eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) + eventsToSend = append([]*rstypes.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) @@ -522,7 +528,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { } } -func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []synctypes.ClientEvent) { +func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *rstypes.HeaderedEvent, chunk []synctypes.ClientEvent) { t.Helper() if wantVisible { for _, ev := range chunk { @@ -612,10 +618,10 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(alice.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -631,10 +637,10 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(alice.ID)) }, @@ -650,7 +656,7 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(alice.ID)) }, @@ -666,7 +672,7 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) }, @@ -706,10 +712,10 @@ func TestGetMembership(t *testing.T) { })) }, additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID)) - room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ "membership": "leave", }, test.WithStateKey(bob.ID)) }, @@ -1141,7 +1147,7 @@ func TestUpdateRelations(t *testing.T) { }, { name: "redactions are ignored", - eventType: gomatrixserverlib.MRoomRedaction, + eventType: spec.MRoomRedaction, eventContent: map[string]interface{}{ "m.relates_to": map[string]interface{}{ "event_id": "$randomEventID", @@ -1227,7 +1233,7 @@ func syncUntil(t *testing.T, } } -func toNATSMsgs(t *testing.T, cfg *config.Dendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { +func toNATSMsgs(t *testing.T, cfg *config.Dendrite, input ...*rstypes.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { var addsStateIDs []string diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 0d1e85bcc5..6f03d9ff0e 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -15,7 +15,10 @@ package synctypes -import "github.com/matrix-org/gomatrixserverlib" +import ( + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" +) type ClientEventFormat int @@ -29,49 +32,56 @@ const ( // ClientEvent is an event which is fit for consumption by clients, in accordance with the specification. type ClientEvent struct { - Content gomatrixserverlib.RawJSON `json:"content"` - EventID string `json:"event_id,omitempty"` // EventID is omitted on receipt events - OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts,omitempty"` // OriginServerTS is omitted on receipt events - RoomID string `json:"room_id,omitempty"` // RoomID is omitted on /sync responses - Sender string `json:"sender,omitempty"` // Sender is omitted on receipt events - StateKey *string `json:"state_key,omitempty"` - Type string `json:"type"` - Unsigned gomatrixserverlib.RawJSON `json:"unsigned,omitempty"` - Redacts string `json:"redacts,omitempty"` + Content spec.RawJSON `json:"content"` + EventID string `json:"event_id,omitempty"` // EventID is omitted on receipt events + OriginServerTS spec.Timestamp `json:"origin_server_ts,omitempty"` // OriginServerTS is omitted on receipt events + RoomID string `json:"room_id,omitempty"` // RoomID is omitted on /sync responses + Sender string `json:"sender,omitempty"` // Sender is omitted on receipt events + SenderKey spec.SenderID `json:"sender_key,omitempty"` // The SenderKey for events in pseudo ID rooms + StateKey *string `json:"state_key,omitempty"` + Type string `json:"type"` + Unsigned spec.RawJSON `json:"unsigned,omitempty"` + Redacts string `json:"redacts,omitempty"` } // ToClientEvents converts server events to client events. -func ToClientEvents(serverEvs []*gomatrixserverlib.Event, format ClientEventFormat) []ClientEvent { +func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent { evs := make([]ClientEvent, 0, len(serverEvs)) for _, se := range serverEvs { if se == nil { continue // TODO: shouldn't happen? } - evs = append(evs, ToClientEvent(se, format)) - } - return evs -} + sender := spec.UserID{} + validRoomID, err := spec.NewRoomID(se.RoomID()) + if err != nil { + continue + } + userID, err := userIDForSender(*validRoomID, se.SenderID()) + if err == nil && userID != nil { + sender = *userID + } -// HeaderedToClientEvents converts headered server events to client events. -func HeaderedToClientEvents(serverEvs []*gomatrixserverlib.HeaderedEvent, format ClientEventFormat) []ClientEvent { - evs := make([]ClientEvent, 0, len(serverEvs)) - for _, se := range serverEvs { - if se == nil { - continue // TODO: shouldn't happen? + sk := se.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk)) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } } - evs = append(evs, HeaderedToClientEvent(se, format)) + evs = append(evs, ToClientEvent(se, format, sender, sk)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se *gomatrixserverlib.Event, format ClientEventFormat) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent { ce := ClientEvent{ - Content: gomatrixserverlib.RawJSON(se.Content()), - Sender: se.Sender(), + Content: spec.RawJSON(se.Content()), + Sender: sender.String(), Type: se.Type(), - StateKey: se.StateKey(), - Unsigned: gomatrixserverlib.RawJSON(se.Unsigned()), + StateKey: stateKey, + Unsigned: spec.RawJSON(se.Unsigned()), OriginServerTS: se.OriginServerTS(), EventID: se.EventID(), Redacts: se.Redacts(), @@ -79,10 +89,32 @@ func ToClientEvent(se *gomatrixserverlib.Event, format ClientEventFormat) Client if format == FormatAll { ce.RoomID = se.RoomID() } + if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + ce.SenderKey = se.SenderID() + } return ce } -// HeaderedToClientEvent converts a single headered server event to a client event. -func HeaderedToClientEvent(se *gomatrixserverlib.HeaderedEvent, format ClientEventFormat) ClientEvent { - return ToClientEvent(se.Event, format) +// ToClientEvent converts a single server event to a client event. +// It provides default logic for event.SenderID & event.StateKey -> userID conversions. +func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { + sender := spec.UserID{} + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return ClientEvent{} + } + userID, err := userIDQuery(*validRoomID, event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey())) + if err == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + return ToClientEvent(event, FormatAll, sender, sk) } diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index ac07917ab0..63c65b2af3 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -21,10 +21,11 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func TestToClientEvent(t *testing.T) { // nolint: gocyclo - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", "state_key": "", "event_id": "$test:localhost", @@ -39,11 +40,16 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo "name": "Goodbye World" } } - }`), false, gomatrixserverlib.RoomVersionV1) + }`), false) if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatAll) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + sk := "" + ce := ToClientEvent(ev, FormatAll, *userID, &sk) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -62,8 +68,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - if ce.Sender != ev.Sender() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender) + if ce.Sender != userID.String() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) } j, err := json.Marshal(ce) if err != nil { @@ -79,7 +85,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo } func TestToClientFormatSync(t *testing.T) { - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", "state_key": "", "event_id": "$test:localhost", @@ -94,11 +100,16 @@ func TestToClientFormatSync(t *testing.T) { "name": "Goodbye World" } } - }`), false, gomatrixserverlib.RoomVersionV1) + }`), false) if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatSync) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + sk := "" + ce := ToClientEvent(ev, FormatSync, *userID, &sk) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/presence.go b/syncapi/types/presence.go index 760225de88..32dc1d828e 100644 --- a/syncapi/types/presence.go +++ b/syncapi/types/presence.go @@ -18,7 +18,7 @@ import ( "strings" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const ( @@ -66,10 +66,10 @@ func PresenceFromString(input string) (Presence, bool) { type PresenceInternal struct { ClientFields PresenceClientResponse - StreamPos StreamPosition `json:"-"` - UserID string `json:"-"` - LastActiveTS gomatrixserverlib.Timestamp `json:"-"` - Presence Presence `json:"-"` + StreamPos StreamPosition `json:"-"` + UserID string `json:"-"` + LastActiveTS spec.Timestamp `json:"-"` + Presence Presence `json:"-"` } type PresenceNotify struct { diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index f2edadea88..e0de2b592f 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -4,11 +4,11 @@ import ( "context" "time" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/syncapi/synctypes" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" ) type SyncRequest struct { @@ -36,11 +36,11 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool { return false } switch membership { - case gomatrixserverlib.Join: + case spec.Join: return true - case gomatrixserverlib.Invite: + case spec.Invite: return true - case gomatrixserverlib.Peek: + case spec.Peek: return true default: return false diff --git a/syncapi/types/types.go b/syncapi/types/types.go index fbc8e36588..06ba464c7d 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -22,10 +22,12 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" ) @@ -38,7 +40,7 @@ var ( type StateDelta struct { RoomID string - StateEvents []*gomatrixserverlib.HeaderedEvent + StateEvents []*types.HeaderedEvent NewlyJoined bool Membership string // The PDU stream position of the latest membership event for this user, if applicable. @@ -59,7 +61,7 @@ func NewStreamPositionFromString(s string) (StreamPosition, error) { // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { - *gomatrixserverlib.HeaderedEvent + *types.HeaderedEvent StreamPosition StreamPosition TransactionID *api.TransactionID ExcludeFromSync bool @@ -352,7 +354,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { type PrevEventRef struct { PrevContent json.RawMessage `json:"prev_content"` ReplacesState string `json:"replaces_state"` - PrevSender string `json:"prev_sender"` + PrevSenderID string `json:"prev_sender"` } type DeviceLists struct { @@ -550,7 +552,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *gomatrixserverlib.HeaderedEvent) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -563,7 +565,7 @@ func NewInviteResponse(event *gomatrixserverlib.HeaderedEvent) *InviteResponse { // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.Unwrap(), synctypes.FormatSync) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) @@ -619,11 +621,11 @@ type Peek struct { // OutputReceiptEvent is an entry in the receipt output kafka log type OutputReceiptEvent struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - EventID string `json:"event_id"` - Type string `json:"type"` - Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp spec.Timestamp `json:"timestamp"` } // OutputSendToDeviceEvent is an entry in the send-to-device output kafka log. diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index cdb22b7ee0..45fc44a474 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -5,10 +5,16 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ "s4_0_0_0_0_0_0_0_3_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3, 0}.String(), @@ -50,12 +56,23 @@ func TestNewInviteResponse(t *testing.T) { event := `{"auth_events":["$SbSsh09j26UAXnjd3RZqf2lyA3Kw2sY_VZJVZQAV9yA","$EwL53onrLwQ5gL8Dv3VrOOCvHiueXu2ovLdzqkNi3lo","$l2wGmz9iAwevBDGpHT_xXLUA5O8BhORxWIGU1cGi1ZM","$GsWFJLXgdlF5HpZeyWkP72tzXYWW3uQ9X28HBuTztHE"],"content":{"avatar_url":"","displayname":"neilalexander","membership":"invite"},"depth":9,"hashes":{"sha256":"8p+Ur4f8vLFX6mkIXhxI0kegPG7X3tWy56QmvBkExAg"},"origin":"matrix.org","origin_server_ts":1602087113066,"prev_events":["$1v-O6tNwhOZcA8bvCYY-Dnj1V2ZDE58lLPxtlV97S28"],"prev_state":[],"room_id":"!XbeXirGWSPXbEaGokF:matrix.org","sender":"@neilalexander:matrix.org","signatures":{"dendrite.neilalexander.dev":{"ed25519:BMJi":"05KQ5lPw0cSFsE4A0x1z7vi/3cc8bG4WHUsFWYkhxvk/XkXMGIYAYkpNThIvSeLfdcHlbm/k10AsBSKH8Uq4DA"},"matrix.org":{"ed25519:a_RXGa":"jeovuHr9E/x0sHbFkdfxDDYV/EyoeLi98douZYqZ02iYddtKhfB7R3WLay/a+D3V3V7IW0FUmPh/A404x5sYCw"}},"state_key":"@neilalexander:dendrite.neilalexander.dev","type":"m.room.member","unsigned":{"age":2512,"invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@neilalexander:matrix.org","state_key":"","type":"m.room.join_rules"},{"content":{"avatar_url":"mxc://matrix.org/BpDaozLwgLnlNStxDxvLzhPr","displayname":"neilalexander","membership":"join"},"sender":"@neilalexander:matrix.org","state_key":"@neilalexander:matrix.org","type":"m.room.member"},{"content":{"name":"Test room"},"sender":"@neilalexander:matrix.org","state_key":"","type":"m.room.name"}]},"_room_version":"5"}` expected := `{"invite_state":{"events":[{"content":{"join_rule":"invite"},"sender":"@neilalexander:matrix.org","state_key":"","type":"m.room.join_rules"},{"content":{"avatar_url":"mxc://matrix.org/BpDaozLwgLnlNStxDxvLzhPr","displayname":"neilalexander","membership":"join"},"sender":"@neilalexander:matrix.org","state_key":"@neilalexander:matrix.org","type":"m.room.member"},{"content":{"name":"Test room"},"sender":"@neilalexander:matrix.org","state_key":"","type":"m.room.name"},{"content":{"avatar_url":"","displayname":"neilalexander","membership":"invite"},"event_id":"$GQmw8e8-26CQv1QuFoHBHpKF1hQj61Flg3kvv_v_XWs","origin_server_ts":1602087113066,"sender":"@neilalexander:matrix.org","state_key":"@neilalexander:dendrite.neilalexander.dev","type":"m.room.member"}]}}` - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(event), false, gomatrixserverlib.RoomVersionV5) + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV5).NewEventFromTrustedJSON([]byte(event), false) + if err != nil { + t.Fatal(err) + } + + sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + if err != nil { + t.Fatal(err) + } + skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true) if err != nil { t.Fatal(err) } + skString := skUserID.String() + sk := &skString - res := NewInviteResponse(ev.Headered(gomatrixserverlib.RoomVersionV5)) + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/test/event.go b/test/event.go index 0c7bf43551..197f80e13e 100644 --- a/test/event.go +++ b/test/event.go @@ -20,12 +20,14 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type eventMods struct { originServerTS time.Time - origin gomatrixserverlib.ServerName + origin spec.ServerName stateKey *string unsigned interface{} keyID gomatrixserverlib.KeyID @@ -71,22 +73,22 @@ func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier { } } -func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier { +func WithOrigin(origin spec.ServerName) eventModifier { return func(e *eventMods) { e.origin = origin } } // Reverse a list of events -func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) +func Reversed(in []*types.HeaderedEvent) []*types.HeaderedEvent { + out := make([]*types.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { out[i] = in[len(in)-i-1] } return out } -func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) { +func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*types.HeaderedEvent) { t.Helper() if len(gotEventIDs) != len(wants) { t.Errorf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants)) @@ -101,7 +103,7 @@ func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixse } } -func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) { +func AssertEventsEqual(t *testing.T, gots, wants []*types.HeaderedEvent) { t.Helper() if len(gots) != len(wants) { t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants)) diff --git a/test/http.go b/test/http.go index 8cd83d0a6a..9a7223b8e2 100644 --- a/test/http.go +++ b/test/http.go @@ -52,7 +52,7 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http // ListenAndServe will listen on a random high-numbered port and attach the given router. // Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) { - listener, err := net.Listen("tcp", ":0") + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen: %s", err) } diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index de0dc54eb5..76034143fa 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -23,7 +23,9 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var nidMutex sync.Mutex @@ -31,28 +33,28 @@ var nid = int64(0) type InMemoryFederationDatabase struct { dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - assumedOffline map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent + pendingPDUServers map[spec.ServerName]struct{} + pendingEDUServers map[spec.ServerName]struct{} + blacklistedServers map[spec.ServerName]struct{} + assumedOffline map[spec.ServerName]struct{} + pendingPDUs map[*receipt.Receipt]*rstypes.HeaderedEvent pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} - relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName + associatedPDUs map[spec.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[spec.ServerName]map[*receipt.Receipt]struct{} + relayServers map[spec.ServerName][]spec.ServerName } func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { return &InMemoryFederationDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingPDUServers: make(map[spec.ServerName]struct{}), + pendingEDUServers: make(map[spec.ServerName]struct{}), + blacklistedServers: make(map[spec.ServerName]struct{}), + assumedOffline: make(map[spec.ServerName]struct{}), + pendingPDUs: make(map[*receipt.Receipt]*rstypes.HeaderedEvent), pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), - relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + associatedPDUs: make(map[spec.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[spec.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[spec.ServerName][]spec.ServerName), } } @@ -63,7 +65,7 @@ func (d *InMemoryFederationDatabase) StoreJSON( d.dbMutex.Lock() defer d.dbMutex.Unlock() - var event gomatrixserverlib.HeaderedEvent + var event rstypes.HeaderedEvent if err := json.Unmarshal([]byte(js), &event); err == nil { nidMutex.Lock() defer nidMutex.Unlock() @@ -88,14 +90,14 @@ func (d *InMemoryFederationDatabase) StoreJSON( func (d *InMemoryFederationDatabase) GetPendingPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, -) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { +) (pdus map[*receipt.Receipt]*rstypes.HeaderedEvent, err error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() pduCount := 0 - pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + pdus = make(map[*receipt.Receipt]*rstypes.HeaderedEvent) if receipts, ok := d.associatedPDUs[serverName]; ok { for dbReceipt := range receipts { if event, ok := d.pendingPDUs[dbReceipt]; ok { @@ -112,7 +114,7 @@ func (d *InMemoryFederationDatabase) GetPendingPDUs( func (d *InMemoryFederationDatabase) GetPendingEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, limit int, ) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { d.dbMutex.Lock() @@ -136,7 +138,7 @@ func (d *InMemoryFederationDatabase) GetPendingEDUs( func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, ) error { d.dbMutex.Lock() @@ -158,7 +160,7 @@ func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( ctx context.Context, - destinations map[gomatrixserverlib.ServerName]struct{}, + destinations map[spec.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -182,7 +184,7 @@ func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( func (d *InMemoryFederationDatabase) CleanPDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { d.dbMutex.Lock() @@ -199,7 +201,7 @@ func (d *InMemoryFederationDatabase) CleanPDUs( func (d *InMemoryFederationDatabase) CleanEDUs( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, receipts []*receipt.Receipt, ) error { d.dbMutex.Lock() @@ -216,7 +218,7 @@ func (d *InMemoryFederationDatabase) CleanEDUs( func (d *InMemoryFederationDatabase) GetPendingPDUCount( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -230,7 +232,7 @@ func (d *InMemoryFederationDatabase) GetPendingPDUCount( func (d *InMemoryFederationDatabase) GetPendingEDUCount( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -244,11 +246,11 @@ func (d *InMemoryFederationDatabase) GetPendingEDUCount( func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() - servers := []gomatrixserverlib.ServerName{} + servers := []spec.ServerName{} for server := range d.pendingPDUServers { servers = append(servers, server) } @@ -257,11 +259,11 @@ func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { +) ([]spec.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() - servers := []gomatrixserverlib.ServerName{} + servers := []spec.ServerName{} for server := range d.pendingEDUServers { servers = append(servers, server) } @@ -269,7 +271,7 @@ func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( } func (d *InMemoryFederationDatabase) AddServerToBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -279,7 +281,7 @@ func (d *InMemoryFederationDatabase) AddServerToBlacklist( } func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -292,12 +294,12 @@ func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { d.dbMutex.Lock() defer d.dbMutex.Unlock() - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + d.blacklistedServers = make(map[spec.ServerName]struct{}) return nil } func (d *InMemoryFederationDatabase) IsServerBlacklisted( - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -312,7 +314,7 @@ func (d *InMemoryFederationDatabase) IsServerBlacklisted( func (d *InMemoryFederationDatabase) SetServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -323,7 +325,7 @@ func (d *InMemoryFederationDatabase) SetServerAssumedOffline( func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -338,13 +340,13 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( d.dbMutex.Lock() defer d.dbMutex.Unlock() - d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + d.assumedOffline = make(map[spec.ServerName]struct{}) return nil } func (d *InMemoryFederationDatabase) IsServerAssumedOffline( ctx context.Context, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (bool, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -359,12 +361,12 @@ func (d *InMemoryFederationDatabase) IsServerAssumedOffline( func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, -) ([]gomatrixserverlib.ServerName, error) { + serverName spec.ServerName, +) ([]spec.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() - knownRelayServers := []gomatrixserverlib.ServerName{} + knownRelayServers := []spec.ServerName{} if relayServers, ok := d.relayServers[serverName]; ok { knownRelayServers = relayServers } @@ -374,8 +376,8 @@ func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -401,8 +403,8 @@ func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( ctx context.Context, - serverName gomatrixserverlib.ServerName, - relayServers []gomatrixserverlib.ServerName, + serverName spec.ServerName, + relayServers []spec.ServerName, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() @@ -426,7 +428,7 @@ func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( return nil } -func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return nil, nil } @@ -446,11 +448,11 @@ func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID return nil, nil } -func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]spec.ServerName, error) { return nil, nil } -func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]spec.ServerName, error) { return nil, nil } @@ -458,19 +460,19 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context. return nil } -func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName spec.ServerName) error { return nil } -func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { return nil, nil } @@ -478,15 +480,15 @@ func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomI return nil, nil } -func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string, renewalInterval int64) error { return nil } -func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName spec.ServerName, roomID, peekID string) (*types.InboundPeek, error) { return nil, nil } @@ -494,11 +496,11 @@ func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID return nil, nil } -func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName spec.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { return nil } -func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName spec.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { return nil, nil } diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go index db93919df7..eecc23fe7d 100644 --- a/test/memory_relay_db.go +++ b/test/memory_relay_db.go @@ -21,13 +21,14 @@ import ( "sync" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) type InMemoryRelayDatabase struct { nid int64 nidMutex sync.Mutex transactions map[int64]json.RawMessage - associations map[gomatrixserverlib.ServerName][]int64 + associations map[spec.ServerName][]int64 } func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { @@ -35,7 +36,7 @@ func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { nid: 1, nidMutex: sync.Mutex{}, transactions: make(map[int64]json.RawMessage), - associations: make(map[gomatrixserverlib.ServerName][]int64), + associations: make(map[spec.ServerName][]int64), } } @@ -43,7 +44,7 @@ func (d *InMemoryRelayDatabase) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, nid int64, ) error { if _, ok := d.associations[serverName]; !ok { @@ -56,7 +57,7 @@ func (d *InMemoryRelayDatabase) InsertQueueEntry( func (d *InMemoryRelayDatabase) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, jsonNIDs []int64, ) error { for _, nid := range jsonNIDs { @@ -72,7 +73,7 @@ func (d *InMemoryRelayDatabase) DeleteQueueEntries( func (d *InMemoryRelayDatabase) SelectQueueEntries( ctx context.Context, - txn *sql.Tx, serverName gomatrixserverlib.ServerName, + txn *sql.Tx, serverName spec.ServerName, limit int, ) ([]int64, error) { results := []int64{} @@ -92,7 +93,7 @@ func (d *InMemoryRelayDatabase) SelectQueueEntries( func (d *InMemoryRelayDatabase) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (int64, error) { return int64(len(d.associations[serverName])), nil } diff --git a/test/room.go b/test/room.go index 685876cb07..da09de7c2a 100644 --- a/test/room.go +++ b/test/room.go @@ -22,8 +22,10 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/eventutil" + rstypes "github.com/matrix-org/dendrite/roomserver/types" ) type Preset int @@ -37,6 +39,10 @@ var ( roomIDCounter = int64(0) ) +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + type Room struct { ID string Version gomatrixserverlib.RoomVersion @@ -46,8 +52,8 @@ type Room struct { creator *User authEvents gomatrixserverlib.AuthEvents - currentState map[string]*gomatrixserverlib.HeaderedEvent - events []*gomatrixserverlib.HeaderedEvent + currentState map[string]*rstypes.HeaderedEvent + events []*rstypes.HeaderedEvent } // Create a new test room. Automatically creates the initial create events. @@ -63,7 +69,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { authEvents: gomatrixserverlib.NewAuthEvents(nil), preset: PresetPublicChat, Version: gomatrixserverlib.RoomVersionV9, - currentState: make(map[string]*gomatrixserverlib.HeaderedEvent), + currentState: make(map[string]*rstypes.HeaderedEvent), visibility: gomatrixserverlib.HistoryVisibilityShared, } for _, m := range modifiers { @@ -73,7 +79,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { return r } -func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference { +func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []string { t.Helper() a, err := needed.AuthEventReferences(&r.authEvents) if err != nil { @@ -111,25 +117,25 @@ func (r *Room) insertCreateEvents(t *testing.T) { hisVis.HistoryVisibility = r.visibility } - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ + r.CreateAndInsert(t, r.creator, spec.MRoomCreate, map[string]interface{}{ "creator": r.creator.ID, "room_version": r.Version, }, WithStateKey("")) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{ + r.CreateAndInsert(t, r.creator, spec.MRoomMember, map[string]interface{}{ "membership": "join", }, WithStateKey(r.creator.ID)) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + r.CreateAndInsert(t, r.creator, spec.MRoomPowerLevels, plContent, WithStateKey("")) + r.CreateAndInsert(t, r.creator, spec.MRoomJoinRules, joinRule, WithStateKey("")) + r.CreateAndInsert(t, r.creator, spec.MRoomHistoryVisibility, hisVis, WithStateKey("")) if r.guestCanJoin { - r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + r.CreateAndInsert(t, r.creator, spec.MRoomGuestAccess, map[string]string{ "guest_access": "can_join", }, WithStateKey("")) } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. -func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { +func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *rstypes.HeaderedEvent { t.Helper() depth := 1 + len(r.events) // depth starts at 1 @@ -152,7 +158,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten mod.origin = creator.srvName } - var unsigned gomatrixserverlib.RawJSON + var unsigned spec.RawJSON var err error if mod.unsigned != nil { unsigned, err = json.Marshal(mod.unsigned) @@ -161,32 +167,26 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten } } - builder := &gomatrixserverlib.EventBuilder{ - Sender: creator.ID, + builder := gomatrixserverlib.MustGetRoomVersion(r.Version).NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{ + SenderID: creator.ID, RoomID: r.ID, Type: eventType, StateKey: mod.stateKey, Depth: int64(depth), Unsigned: unsigned, - } + }) err = builder.SetContent(content) if err != nil { t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err) } if depth > 1 { - builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()} - } - - eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) - if err != nil { - t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err) + builder.PrevEvents = []string{r.events[len(r.events)-1].EventID()} } - refs, err := eventsNeeded.AuthEventReferences(&r.authEvents) + err = builder.AddAuthEvents(&r.authEvents) if err != nil { t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err) } - builder.AuthEvents = refs if len(mod.authEvents) > 0 { builder.AuthEvents = mod.authEvents @@ -194,26 +194,26 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten ev, err := builder.Build( mod.originServerTS, mod.origin, mod.keyID, - mod.privKey, r.Version, + mod.privKey, ) if err != nil { t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) } - if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } - headeredEvent := ev.Headered(r.Version) + headeredEvent := &rstypes.HeaderedEvent{PDU: ev} headeredEvent.Visibility = r.visibility return headeredEvent } // Add a new event to this room DAG. Not thread-safe. -func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { +func (r *Room) InsertEvent(t *testing.T, he *rstypes.HeaderedEvent) { t.Helper() // Add the event to the list of auth/state events r.events = append(r.events, he) if he.StateKey() != nil { - err := r.authEvents.AddEvent(he.Unwrap()) + err := r.authEvents.AddEvent(he.PDU) if err != nil { t.Fatalf("InsertEvent: failed to add event to auth events: %s", err) } @@ -221,12 +221,12 @@ func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { } } -func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent { +func (r *Room) Events() []*rstypes.HeaderedEvent { return r.events } -func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent { - events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState)) +func (r *Room) CurrentState() []*rstypes.HeaderedEvent { + events := make([]*rstypes.HeaderedEvent, len(r.currentState)) i := 0 for _, e := range r.currentState { events[i] = e @@ -235,7 +235,7 @@ func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent { return events } -func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { +func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *rstypes.HeaderedEvent { t.Helper() he := r.CreateEvent(t, creator, eventType, content, mods...) r.InsertEvent(t, he) diff --git a/test/user.go b/test/user.go index 63206fa162..9509b95a62 100644 --- a/test/user.go +++ b/test/user.go @@ -23,12 +23,13 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var ( userIDCounter = int64(0) - serverName = gomatrixserverlib.ServerName("test") + serverName = spec.ServerName("test") keyID = gomatrixserverlib.KeyID("ed25519:test") privateKey = ed25519.NewKeyFromSeed([]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -53,12 +54,12 @@ type User struct { // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey - srvName gomatrixserverlib.ServerName + srvName spec.ServerName } type UserOpt func(*User) -func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt { +func WithSigningServer(srvName spec.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt { return func(u *User) { u.keyID = keyID u.privKey = privKey diff --git a/userapi/api/api.go b/userapi/api/api.go index 5f1d361b22..a072903a55 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -25,7 +25,9 @@ import ( "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" ) @@ -62,10 +64,10 @@ type FederationUserAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) } // api functions required by the sync api @@ -86,12 +88,18 @@ type ClientUserAPI interface { UserLoginAPI ClientKeyAPI ProfileAPI + KeyBackupAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error - QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error + QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error + PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -99,13 +107,11 @@ type ClientUserAPI interface { PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error - PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error - PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error - QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error @@ -113,10 +119,17 @@ type ClientUserAPI interface { PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error } +type KeyBackupAPI interface { + DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest) (string, error) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest) (*QueryKeyBackupResponse, error) + UpdateBackupKeyAuthData(ctx context.Context, req *PerformKeyBackupRequest) (*PerformKeyBackupResponse, error) +} + type ProfileAPI interface { QueryProfile(ctx context.Context, userID string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) } // custom api functions required by pinecone / p2p demos @@ -136,11 +149,10 @@ type UserLoginAPI interface { } type PerformKeyBackupRequest struct { - UserID string - Version string // optional if modifying a key backup - AuthData json.RawMessage - Algorithm string - DeleteBackup bool // if true will delete the backup based on 'Version'. + UserID string + Version string // optional if modifying a key backup + AuthData json.RawMessage + Algorithm string // The keys to upload, if any. If blank, creates/updates/deletes key version metadata only. Keys struct { @@ -181,9 +193,6 @@ type InternalKeyBackupSession struct { } type PerformKeyBackupResponse struct { - Error string // set if there was a problem performing the request - BadInput bool // if set, the Error was due to bad input (HTTP 400) - Exists bool // set to true if the Version exists Version string // the newly created version @@ -201,7 +210,6 @@ type QueryKeyBackupRequest struct { } type QueryKeyBackupResponse struct { - Error string Exists bool Algorithm string `json:"algorithm"` @@ -315,9 +323,9 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. - ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + ServerName spec.ServerName // optional: if not specified, default server name used instead AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account @@ -332,10 +340,10 @@ type PerformAccountCreationResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformPasswordUpdateRequest struct { - Localpart string // Required: The localpart for this account. - ServerName gomatrixserverlib.ServerName // Required: The domain for this account. - Password string // Required: The new password to set. - LogoutDevices bool // Optional: Whether to log out all user devices. + Localpart string // Required: The localpart for this account. + ServerName spec.ServerName // Required: The domain for this account. + Password string // Required: The new password to set. + LogoutDevices bool // Optional: Whether to log out all user devices. } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -359,8 +367,8 @@ type PerformLastSeenUpdateResponse struct { // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used - AccessToken string // optional: if blank one will be made on your behalf + ServerName spec.ServerName // optional: if blank, default server name used + AccessToken string // optional: if blank one will be made on your behalf // optional: if nil an ID is generated for you. If set, replaces any existing device session, // which will generate a new access token and invalidate the old one. DeviceID *string @@ -385,7 +393,7 @@ type PerformDeviceCreationResponse struct { // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used + ServerName spec.ServerName // optional: if blank, default server name used } // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation @@ -435,7 +443,7 @@ type Device struct { AccountType AccountType } -func (d *Device) UserDomain() gomatrixserverlib.ServerName { +func (d *Device) UserDomain() spec.ServerName { _, domain, err := gomatrixserverlib.SplitID('@', d.UserID) if err != nil { // This really is catastrophic because it means that someone @@ -451,7 +459,7 @@ func (d *Device) UserDomain() gomatrixserverlib.ServerName { type Account struct { UserID string Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName AppServiceID string AccountType AccountType // TODO: Associations (e.g. with application services) @@ -517,7 +525,7 @@ const ( type QueryPushersRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryPushersResponse struct { @@ -527,13 +535,13 @@ type QueryPushersResponse struct { type PerformPusherSetRequest struct { Pusher // Anonymous field because that's how clientapi unmarshals it. Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Append bool `json:"append"` } type PerformPusherDeletionRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName SessionID int64 } @@ -558,25 +566,12 @@ const ( HTTPKind PusherKind = "http" ) -type PerformPushRulesPutRequest struct { - UserID string `json:"user_id"` - RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` -} - -type QueryPushRulesRequest struct { - UserID string `json:"user_id"` -} - -type QueryPushRulesResponse struct { - RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` -} - type QueryNotificationsRequest struct { - Localpart string `json:"localpart"` // Required. - ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. - From string `json:"from,omitempty"` - Limit int `json:"limit,omitempty"` - Only string `json:"only,omitempty"` + Localpart string `json:"localpart"` // Required. + ServerName spec.ServerName `json:"server_name"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` } type QueryNotificationsResponse struct { @@ -585,16 +580,16 @@ type QueryNotificationsResponse struct { } type Notification struct { - Actions []*pushrules.Action `json:"actions"` // Required. - Event synctypes.ClientEvent `json:"event"` // Required. - ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. - Read bool `json:"read"` // Required. - RoomID string `json:"room_id"` // Required. - TS gomatrixserverlib.Timestamp `json:"ts"` // Required. + Actions []*pushrules.Action `json:"actions"` // Required. + Event synctypes.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS spec.Timestamp `json:"ts"` // Required. } type QueryNumericLocalpartRequest struct { - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryNumericLocalpartResponse struct { @@ -603,7 +598,7 @@ type QueryNumericLocalpartResponse struct { type QueryAccountAvailabilityRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryAccountAvailabilityResponse struct { @@ -612,7 +607,7 @@ type QueryAccountAvailabilityResponse struct { type QueryAccountByPasswordRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName PlaintextPassword string } @@ -627,12 +622,12 @@ type QueryLocalpartForThreePIDRequest struct { type QueryLocalpartForThreePIDResponse struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryThreePIDsForLocalpartRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryThreePIDsForLocalpartResponse struct { @@ -644,13 +639,13 @@ type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformSaveThreePIDAssociationRequest struct { ThreePID string Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName Medium string } type QueryAccountByLocalpartRequest struct { Localpart string - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName } type QueryAccountByLocalpartResponse struct { @@ -660,17 +655,17 @@ type QueryAccountByLocalpartResponse struct { // API functions required by the clientapi type ClientKeyAPI interface { UploadDeviceKeysAPI - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error } type UploadDeviceKeysAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) } // API functions required by the syncapi @@ -682,10 +677,10 @@ type SyncKeyAPI interface { type FederationKeyAPI interface { UploadDeviceKeysAPI - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) } // KeyError is returned if there was a problem performing/querying the server @@ -960,6 +955,6 @@ type QuerySignaturesResponse struct { type PerformMarkAsStaleRequest struct { UserID string - Domain gomatrixserverlib.ServerName + Domain spec.ServerName DeviceID string } diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 51bd2753a1..ba72ff350c 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -18,6 +18,7 @@ import ( "context" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -38,7 +39,7 @@ type OutputReceiptEventConsumer struct { durable string topic string db storage.UserDatabase - serverName gomatrixserverlib.ServerName + serverName spec.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client } @@ -104,7 +105,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return false } - updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(spec.AsTimestamp(metadata.Timestamp)), true) if err != nil { log.WithError(err).Error("userapi EDU consumer") return false diff --git a/userapi/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go index a65889fcc2..3389bb808d 100644 --- a/userapi/consumers/devicelistupdate.go +++ b/userapi/consumers/devicelistupdate.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -35,7 +36,7 @@ type DeviceListUpdateConsumer struct { durable string topic string updater *internal.DeviceListUpdater - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -72,7 +73,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M logrus.WithError(err).Errorf("Failed to read from device list update input topic") return true } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + origin := spec.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { return true } else if t.isLocalServerName(serverName) { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 6704658df1..9cb9419d4f 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -13,6 +13,7 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -20,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushrules" rsapi "github.com/matrix-org/dendrite/roomserver/api" + rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -43,11 +45,11 @@ type OutputRoomEventConsumer struct { topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI - msgCounts map[gomatrixserverlib.ServerName]userAPITypes.MessageStats - roomCounts map[gomatrixserverlib.ServerName]map[string]bool // map from serverName to map from rommID to "isEncrypted" + msgCounts map[spec.ServerName]userAPITypes.MessageStats + roomCounts map[spec.ServerName]map[string]bool // map from serverName to map from rommID to "isEncrypted" lastUpdate time.Time countsLock sync.Mutex - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } func NewOutputRoomEventConsumer( @@ -69,8 +71,8 @@ func NewOutputRoomEventConsumer( pgClient: pgClient, rsAPI: rsAPI, syncProducer: syncProducer, - msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, - roomCounts: map[gomatrixserverlib.ServerName]map[string]bool{}, + msgCounts: map[spec.ServerName]userAPITypes.MessageStats{}, + roomCounts: map[spec.ServerName]map[string]bool{}, lastUpdate: time.Now(), countsLock: sync.Mutex{}, serverName: cfg.Matrix.ServerName, @@ -106,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms } if s.cfg.Matrix.ReportStats.Enabled { - go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID()) + go s.storeMessageStats(ctx, event.Type(), string(event.SenderID()), event.RoomID()) } log.WithFields(log.Fields{ @@ -119,7 +121,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } - if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil { + if err := s.processMessage(ctx, event, uint64(spec.AsTimestamp(metadata.Timestamp))); err != nil { log.WithFields(log.Fields{ "event_id": event.EventID(), }).WithError(err).Errorf("userapi consumer: process room event failure") @@ -210,7 +212,7 @@ func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoom return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error { +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName spec.ServerName) error { pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) if err != nil { return fmt.Errorf("failed to query pushrules for user: %w", err) @@ -238,7 +240,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName, roomSize int) error { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { return nil @@ -280,7 +282,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, return nil } -func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error { +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) error { tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") if err != nil && !errors.Is(err, sql.ErrNoRows) { return err @@ -291,21 +293,39 @@ func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRo return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag) } -func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { +func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rstypes.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) } switch { - case event.Type() == gomatrixserverlib.MRoomMember: - cevent := synctypes.HeaderedToClientEvent(event, synctypes.FormatAll) + case event.Type() == spec.MRoomMember: + sender := spec.UserID{} + validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + if roomErr != nil { + return roomErr + } + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { return fmt.Errorf("newLocalMembership: %w", err) } - if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + if member.Membership == spec.Invite && member.Domain == s.cfg.Matrix.ServerName { // localRoomMembers only adds joined members. An invite // should also be pushed to the target user. members = append(members, member) @@ -356,7 +376,7 @@ type localMembership struct { gomatrixserverlib.MemberContent UserID string Localpart string - Domain gomatrixserverlib.ServerName + Domain spec.ServerName } func newLocalMembership(event *synctypes.ClientEvent) (*localMembership, error) { @@ -418,7 +438,7 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s log.WithError(err).Errorf("Parsing MemberContent") continue } - if member.Membership != gomatrixserverlib.Join { + if member.Membership != spec.Join { continue } if member.Domain != s.cfg.Matrix.ServerName { @@ -435,8 +455,8 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s // looks it up in roomserver. If there is no name, // m.room.canonical_alias is consulted. Returns an empty string if the // room has no name. -func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { - if event.Type() == gomatrixserverlib.MRoomName { +func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *rstypes.HeaderedEvent) (string, error) { + if event.Type() == spec.MRoomName { name, err := unmarshalRoomName(event) if err != nil { return "", err @@ -461,7 +481,7 @@ func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixs return unmarshalRoomName(eventS) } - if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + if event.Type() == spec.MRoomCanonicalAlias { alias, err := unmarshalCanonicalAlias(event) if err != nil { return "", err @@ -480,11 +500,11 @@ func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixs } var ( - canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} - roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomName} ) -func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { +func unmarshalRoomName(event *rstypes.HeaderedEvent) (string, error) { var nc eventutil.NameContent if err := json.Unmarshal(event.Content(), &nc); err != nil { return "", fmt.Errorf("unmarshaling NameContent: %w", err) @@ -493,7 +513,7 @@ func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { return nc.Name, nil } -func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { +func unmarshalCanonicalAlias(event *rstypes.HeaderedEvent) (string, error) { var cac eventutil.CanonicalAliasContent if err := json.Unmarshal(event.Content(), &cac); err != nil { return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) @@ -503,7 +523,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er } // notifyLocal finds the right push actions for a local user, given an event. -func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { +func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { return fmt.Errorf("s.evaluatePushRules: %w", err) @@ -527,19 +547,37 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr return fmt.Errorf("s.localPushDevices: %w", err) } + sender := spec.UserID{} + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + + sk := event.StateKey() + if sk != nil && *sk != "" { + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) + if queryErr == nil && skUserID != nil { + skString := skUserID.String() + sk = &skString + } + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.HeaderedToClientEvent(event, synctypes.FormatSync), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it // to "work", but they only use a single device. ProfileTag: profileTag, RoomID: event.RoomID(), - TS: gomatrixserverlib.AsTimestamp(time.Now()), + TS: spec.AsTimestamp(time.Now()), } if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil { return fmt.Errorf("s.db.InsertNotification: %w", err) @@ -612,8 +650,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). -func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - if event.Sender() == mem.UserID { +func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { + user := "" + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err == nil { + user = sender.String() + } + if user == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. return nil, nil @@ -630,9 +677,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if err != nil { return nil, err } - sender := event.Sender() - if _, ok := ignored.List[sender]; ok { - return nil, fmt.Errorf("user %s is ignored", sender) + if _, ok := ignored.List[sender.String()]; ok { + return nil, fmt.Errorf("user %s is ignored", sender.String()) } } ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) @@ -648,7 +694,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.Event) + rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } @@ -680,11 +728,11 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } -func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { +func (rse *ruleSetEvalContext) HasPowerLevel(senderID spec.SenderID, levelKey string) (bool, error) { req := &rsapi.QueryLatestEventsAndStateRequest{ RoomID: rse.roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels}, + {EventType: spec.MRoomPowerLevels}, }, } var res rsapi.QueryLatestEventsAndStateResponse @@ -692,22 +740,22 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err return false, err } for _, ev := range res.StateEvents { - if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + if ev.Type() != spec.MRoomPowerLevels { continue } - plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) + plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.PDU) if err != nil { return false, err } - return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil } return true, nil } // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName spec.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) if err != nil { return nil, "", fmt.Errorf("util.GetPushDevices: %w", err) @@ -731,7 +779,7 @@ func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpar } // notifyHTTP performs a notificatation to a Push Gateway. -func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { +func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { logger := log.WithFields(log.Fields{ "event_id": event.EventID(), "url": url, @@ -754,6 +802,15 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri } default: + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + if err != nil { + logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) + return nil, err + } req = pushgateway.NotifyRequest{ Notification: pushgateway.Notification{ Content: event.Content(), @@ -765,14 +822,30 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri ID: event.EventID(), RoomID: event.RoomID(), RoomName: roomName, - Sender: event.Sender(), + Sender: sender.String(), Type: event.Type(), }, } - if mem, err := event.Membership(); err == nil { + if mem, memberErr := event.Membership(); memberErr == nil { req.Notification.Membership = mem } - if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + userID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName), true) + if err != nil { + logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) + return nil, err + } + roomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID()) + return nil, err + } + + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) + if err != nil { + logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) + return nil, err + } + if event.StateKey() != nil && *event.StateKey() == string(localSender) { req.Notification.UserIsTarget = true } } @@ -805,7 +878,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName spec.ServerName) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 4827ad47c2..4dc81e74aa 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -8,10 +8,13 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi/storage" @@ -33,13 +36,19 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, } } -func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { +func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { t.Helper() - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV10).NewEventFromTrustedJSON([]byte(content), false) if err != nil { t.Fatalf("failed to create event: %v", err) } - return ev.Headered(gomatrixserverlib.RoomVersionV10) + return &types.HeaderedEvent{PDU: ev} +} + +type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } + +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func Test_evaluatePushRules(t *testing.T) { @@ -48,7 +57,7 @@ func Test_evaluatePushRules(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - consumer := OutputRoomEventConsumer{db: db} + consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}} testCases := []struct { name string @@ -59,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) { }{ { name: "m.receipt doesn't notify", - eventContent: `{"type":"m.receipt"}`, + eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`, wantAction: pushrules.UnknownAction, wantActions: nil, }, { name: "m.reaction doesn't notify", - eventContent: `{"type":"m.reaction"}`, + eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`, wantAction: pushrules.DontNotifyAction, wantActions: []*pushrules.Action{ { @@ -75,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message notifies", - eventContent: `{"type":"m.room.message"}`, + eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ @@ -84,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ @@ -139,9 +148,9 @@ func TestMessageStats(t *testing.T) { tests := []struct { name string args args - ourServer gomatrixserverlib.ServerName + ourServer spec.ServerName lastUpdate time.Time - initRoomCounts map[gomatrixserverlib.ServerName]map[string]bool + initRoomCounts map[spec.ServerName]map[string]bool wantStats userAPITypes.MessageStats }{ { @@ -197,7 +206,7 @@ func TestMessageStats(t *testing.T) { name: "day change creates a new room map", ourServer: "localhost", lastUpdate: time.Now().Add(-time.Hour * 24), - initRoomCounts: map[gomatrixserverlib.ServerName]map[string]bool{ + initRoomCounts: map[spec.ServerName]map[string]bool{ "localhost": {"encryptedRoom": true}, }, args: args{ @@ -219,11 +228,11 @@ func TestMessageStats(t *testing.T) { tt.lastUpdate = time.Now() } if tt.initRoomCounts == nil { - tt.initRoomCounts = map[gomatrixserverlib.ServerName]map[string]bool{} + tt.initRoomCounts = map[spec.ServerName]map[string]bool{} } s := &OutputRoomEventConsumer{ db: db, - msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, + msgCounts: map[spec.ServerName]userAPITypes.MessageStats{}, roomCounts: tt.initRoomCounts, countsLock: sync.Mutex{}, lastUpdate: tt.lastUpdate, diff --git a/userapi/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go index 006ccb728f..9de866343f 100644 --- a/userapi/consumers/signingkeyupdate.go +++ b/userapi/consumers/signingkeyupdate.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -37,7 +38,7 @@ type SigningKeyUpdateConsumer struct { topic string userAPI api.UploadDeviceKeysAPI cfg *config.UserAPI - isLocalServerName func(gomatrixserverlib.ServerName) bool + isLocalServerName func(spec.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. @@ -75,7 +76,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M logrus.WithError(err).Errorf("Failed to read from signing key update input topic") return true } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + origin := spec.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { logrus.WithError(err).Error("failed to split user id") return true @@ -99,10 +100,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M UserID: updatePayload.UserID, } uploadRes := &api.PerformUploadDeviceKeysResponse{} - if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { - logrus.WithError(err).Error("failed to upload device keys") - return false - } + t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) if uploadRes.Error != nil { logrus.WithError(uploadRes.Error).Error("failed to upload device keys") return true diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go index 23b6207e23..be05841c4e 100644 --- a/userapi/internal/cross_signing.go +++ b/userapi/internal/cross_signing.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "golang.org/x/crypto/curve25519" ) @@ -104,7 +105,7 @@ func sanityCheckKey(key fclient.CrossSigningKey, userID string, purpose fclient. } // nolint:gocyclo -func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { // Find the keys to store. byPurpose := map[fclient.CrossSigningKeyPurpose]fclient.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -116,7 +117,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "Master key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return nil + return } byPurpose[fclient.CrossSigningKeyPurposeMaster] = req.MasterKey @@ -132,7 +133,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "Self-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return nil + return } byPurpose[fclient.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey @@ -147,7 +148,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "User-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return nil + return } byPurpose[fclient.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey @@ -162,7 +163,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "No keys were supplied in the request", IsMissingParam: true, } - return nil + return } // We can't have a self-signing or user-signing key without a master @@ -175,7 +176,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - return nil + return } // If we still can't find a master key for the user then stop the upload. @@ -186,7 +187,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. Err: "No master key was found", IsMissingParam: true, } - return nil + return } } @@ -213,7 +214,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. } } if !changed { - return nil + return } // Store the keys. @@ -221,7 +222,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } - return nil + return } // Now upload any signatures that were included with the keys. @@ -239,7 +240,7 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } - return nil + return } } } @@ -256,18 +257,16 @@ func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api. update.SelfSigningKey = &ssk } if update.MasterKey == nil && update.SelfSigningKey == nil { - return nil + return } if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return nil } - return nil } -func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -278,7 +277,7 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req for userID := range req.Signatures { queryReq.UserToDevices[userID] = []string{} } - _ = a.QueryKeys(ctx, queryReq, queryRes) + a.QueryKeys(ctx, queryReq, queryRes) selfSignatures := map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} otherSignatures := map[string]map[gomatrixserverlib.KeyID]fclient.CrossSigningForKeyOrDevice{} @@ -324,14 +323,14 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processSelfSignatures: %s", err), } - return nil + return } if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processOtherSignatures: %s", err), } - return nil + return } // Finally, generate a notification that we updated the signatures. @@ -347,10 +346,9 @@ func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return nil + return } } - return nil } func (a *UserInternalAPI) processSelfSignatures( @@ -485,12 +483,12 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( continue } - appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { + appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) { if key.Signatures == nil { key.Signatures = types.CrossSigningSigMap{} } if _, ok := key.Signatures[originUserID]; !ok { - key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) + key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]spec.Base64Bytes) } key.Signatures[originUserID][originKeyID] = signature } @@ -523,7 +521,7 @@ func (a *UserInternalAPI) crossSigningKeysFromDatabase( } } -func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) { for targetUserID, forTargetUser := range req.TargetIDs { keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { @@ -562,7 +560,7 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), } - return nil + return } for sourceUserID, forSourceUser := range sigMap { @@ -577,12 +575,11 @@ func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySig res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{} } if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok { - res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig } } } } - return nil } diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go index a274e1ae3b..3fccf56bb5 100644 --- a/userapi/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -26,6 +26,7 @@ import ( rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" @@ -98,8 +99,8 @@ type DeviceListUpdater struct { api DeviceListUpdaterAPI producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI - workerChans []chan gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName + workerChans []chan spec.ServerName + thisServer spec.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -113,7 +114,7 @@ type DeviceListUpdater struct { type DeviceListUpdaterDatabase interface { // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error @@ -133,7 +134,7 @@ type DeviceListUpdaterDatabase interface { } type DeviceListUpdaterAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) } // KeyChangeProducer is the interface for producers.KeyChange useful for testing. @@ -146,7 +147,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer spec.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -157,7 +158,7 @@ func NewDeviceListUpdater( producer: producer, fedClient: fedClient, thisServer: thisServer, - workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), + workerChans: make([]chan spec.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, rsAPI: rsAPI, @@ -170,12 +171,12 @@ func (u *DeviceListUpdater) Start() error { // Allocate a small buffer per channel. // If the buffer limit is reached, backpressure will cause the processing of EDUs // to stop (in this transaction) until key requests can be made. - ch := make(chan gomatrixserverlib.ServerName, 10) + ch := make(chan spec.ServerName, 10) u.workerChans[i] = ch go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } @@ -195,7 +196,7 @@ func (u *DeviceListUpdater) Start() error { // CleanUp removes stale device entries for users we don't share a room with anymore func (u *DeviceListUpdater) CleanUp() error { - staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } @@ -223,7 +224,7 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { // ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. // Blocks until the device list is synced or the timeout is reached. -func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { +func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName spec.ServerName, userID string) error { mu := u.mutex(userID) mu.Lock() err := u.db.MarkDeviceListStale(ctx, userID, true) @@ -369,12 +370,12 @@ func (u *DeviceListUpdater) clearChannel(userID string) { } } -func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { - retries := make(map[gomatrixserverlib.ServerName]time.Time) +func (u *DeviceListUpdater) worker(ch chan spec.ServerName) { + retries := make(map[spec.ServerName]time.Time) retriesMu := &sync.Mutex{} // restarter goroutine which will inject failed servers into ch when it is time go func() { - var serversToRetry []gomatrixserverlib.ServerName + var serversToRetry []spec.ServerName for { serversToRetry = serversToRetry[:0] // reuse memory time.Sleep(time.Second) @@ -413,7 +414,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { } } -func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { +func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) { ctx := u.process.Context() logger := util.GetLogger(ctx).WithField("server_name", serverName) deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() @@ -421,7 +422,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam waitTime := defaultWaitTime // How long should we wait to try again? successCount := 0 // How many user requests failed? - userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) + userIDs, err := u.db.StaleDeviceLists(ctx, []spec.ServerName{serverName}) if err != nil { logger.WithError(err).Error("Failed to load stale device lists") return waitTime, true @@ -457,7 +458,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam return waitTime, !allUsersSucceeded } -func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { +func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) { ctx, cancel := context.WithTimeout(ctx, requestTimeout) defer cancel() logger := util.GetLogger(ctx).WithFields(logrus.Fields{ @@ -518,7 +519,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go uploadReq.SelfSigningKey = *res.SelfSigningKey } } - _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) } err = u.updateDeviceList(&res) if err != nil { diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 47b31c6850..10b9c6521f 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -65,7 +66,7 @@ func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Conte // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. -func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { d.mu.Lock() defer d.mu.Unlock() var result []string @@ -124,8 +125,7 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys type mockDeviceListUpdaterAPI struct { } -func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { - return nil +func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { } type roundTripper struct { @@ -136,18 +136,16 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return t.fn(req) } -func newFedClient(tripper func(*http.Request) (*http.Response, error)) *fclient.FederationClient { +func newFedClient(tripper func(*http.Request) (*http.Response, error)) fclient.FederationClient { _, pkey, _ := ed25519.GenerateKey(nil) fedClient := fclient.NewFederationClient( []*fclient.SigningIdentity{ { - ServerName: gomatrixserverlib.ServerName("example.test"), + ServerName: spec.ServerName("example.test"), KeyID: gomatrixserverlib.KeyID("ed25519:test"), PrivateKey: pkey, }, }, - ) - fedClient.Client = *fclient.NewClient( fclient.WithTransport(&roundTripper{tripper}), ) return fedClient @@ -294,7 +292,7 @@ func TestDebounce(t *testing.T) { ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} fedCh := make(chan *http.Response, 1) - srv := gomatrixserverlib.ServerName("example.com") + srv := spec.ServerName("example.com") userID := "@alice:example.com" keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` incomingFedReq := make(chan struct{}) @@ -414,7 +412,7 @@ func TestDeviceListUpdater_CleanUp(t *testing.T) { } // check that we still have Alice in our stale list - staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + staleUsers, err := db.StaleDeviceLists(ctx, []spec.ServerName{"test"}) if err != nil { t.Error(err) } diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 043028725f..786a2dcd89 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -62,7 +63,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor return nil } -func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -80,7 +81,7 @@ func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.Perform domainToDeviceKeys[string(serverName)] = nested } for domain, local := range domainToDeviceKeys { - if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if !a.Config.Matrix.IsLocalServerName(spec.ServerName(domain)) { continue } // claim local keys @@ -109,7 +110,6 @@ func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.Perform if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } - return nil } func (a *UserInternalAPI) claimRemoteKeys( @@ -129,7 +129,7 @@ func (a *UserInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, spec.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -227,7 +227,7 @@ func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *a } // nolint:gocyclo -func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { var respMu sync.Mutex res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]fclient.CrossSigningKey) @@ -251,7 +251,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), } - return nil + return } // pull out display names after we have the keys so we handle wildcards correctly @@ -321,7 +321,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque for targetUserID, masterKey := range res.MasterKeys { if masterKey.Signatures == nil { - masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } for targetKeyID := range masterKey.Keys { sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) @@ -329,7 +329,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil + return } logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue @@ -340,7 +340,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { if _, ok := masterKey.Signatures[sourceUserID]; !ok { - masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } @@ -355,7 +355,7 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil + return } logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") continue @@ -368,12 +368,12 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque continue } if deviceKey.Signatures == nil { - deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { if _, ok := deviceKey.Signatures[sourceUserID]; !ok { - deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } @@ -383,7 +383,6 @@ func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReque } } } - return nil } func (a *UserInternalAPI) remoteKeysFromDatabase( @@ -424,13 +423,13 @@ func (a *UserInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(spec.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + if a.Config.Matrix.IsLocalServerName(spec.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -514,7 +513,7 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( } } for userID := range userIDsForAllDevices { - err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) + err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID) if err != nil { logrus.WithFields(logrus.Fields{ logrus.ErrorKey: err, @@ -542,7 +541,7 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, spec.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return @@ -671,7 +670,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe } else { // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) if err != nil { continue // ignore invalid users diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 826369cb82..367d1ccc95 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -26,11 +26,14 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" @@ -61,6 +64,37 @@ type UserInternalAPI struct { Updater *DeviceListUpdater } +func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token) + if err != nil { + return false, err + } + if exists { + return false, fmt.Errorf("token: %s already exists", *registrationToken.Token) + } + _, err = a.DB.InsertRegistrationToken(ctx, registrationToken) + if err != nil { + return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token) + } + return true, nil +} + +func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return a.DB.ListRegistrationTokens(ctx, returnAll, valid) +} + +func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return a.DB.GetRegistrationToken(ctx, tokenString) +} + +func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { + return a.DB.DeleteRegistrationToken(ctx, tokenString) +} + +func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { + return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { @@ -112,7 +146,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(spec.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err @@ -170,8 +204,8 @@ func addUserToRoom( UserID: userID, Content: addGroupContent, } - joinRes := rsapi.PerformJoinResponse{} - return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) + _, _, err := rsAPI.PerformJoin(ctx, &joinReq) + return err } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { @@ -627,22 +661,17 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a return fmt.Errorf("server name %q not locally configured", serverName) } - evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), - } - evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} - if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { - return err - } - if err := evacuateRes.Error; err != nil { - logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") + userID := fmt.Sprintf("@%s:%s", req.Localpart, serverName) + _, err := a.RSAPI.PerformAdminEvacuateUser(ctx, userID) + if err != nil { + logrus.WithError(err).WithField("userID", userID).Errorf("Failed to evacuate user after account deactivation") } deviceReq := &api.PerformDeviceDeletionRequest{ UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } deviceRes := &api.PerformDeviceDeletionResponse{} - if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { + if err = a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { return err } @@ -697,62 +726,43 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp return nil } -func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { - // Delete metadata - if req.DeleteBackup { - if req.Version == "" { - res.BadInput = true - res.Error = "must specify a version to delete" - return nil - } - exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) - if err != nil { - res.Error = fmt.Sprintf("failed to delete backup: %s", err) - } - res.Exists = exists - res.Version = req.Version - return nil - } +func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) { + return a.DB.DeleteKeyBackup(ctx, userID, version) +} + +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest) (string, error) { // Create metadata - if req.Version == "" { - version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to create backup: %s", err) - } - res.Exists = err == nil - res.Version = version - return nil - } + return a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) +} + +func (a *UserInternalAPI) UpdateBackupKeyAuthData(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) { + res := &api.PerformKeyBackupResponse{} // Update metadata if len(req.Keys.Rooms) == 0 { err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to update backup: %s", err) - } res.Exists = err == nil res.Version = req.Version - return nil + if err != nil { + return res, fmt.Errorf("failed to update backup: %w", err) + } + return res, nil } // Upload Keys for a specific version metadata - a.uploadBackupKeys(ctx, req, res) - return nil + return a.uploadBackupKeys(ctx, req) } -func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { +func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) { + res := &api.PerformKeyBackupResponse{} // you can only upload keys for the CURRENT version version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { - res.Error = fmt.Sprintf("failed to query version: %s", err) - return + return res, fmt.Errorf("failed to query version: %w", err) } if deleted { - res.Error = "backup was deleted" - return + return res, fmt.Errorf("backup was deleted") } if version != req.Version { - res.BadInput = true - res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) - return + return res, spec.WrongBackupVersionError(version) } res.Exists = true res.Version = version @@ -770,23 +780,25 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform } count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { - res.Error = fmt.Sprintf("failed to upsert keys: %s", err) - return + return res, fmt.Errorf("failed to upsert keys: %w", err) } res.KeyCount = count res.KeyETag = etag + return res, nil } -func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest) (*api.QueryKeyBackupResponse, error) { + res := &api.QueryKeyBackupResponse{} version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { - if err == sql.ErrNoRows { - res.Exists = false - return nil + if errors.Is(err, sql.ErrNoRows) { + return res, nil } - res.Error = fmt.Sprintf("failed to query key backup: %s", err) - return nil + if errors.Is(err, strconv.ErrSyntax) { + return res, nil + } + return res, fmt.Errorf("failed to query key backup: %s", err) } res.Algorithm = algorithm res.AuthData = authData @@ -796,18 +808,17 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB if !req.ReturnKeys { res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { - res.Error = fmt.Sprintf("failed to count keys: %s", err) + return res, fmt.Errorf("failed to count keys: %w", err) } - return nil + return res, nil } result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { - res.Error = fmt.Sprintf("failed to query keys: %s", err) - return nil + return res, fmt.Errorf("failed to query keys: %s", err) } res.Keys = result - return nil + return res, nil } func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { @@ -888,39 +899,31 @@ func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPusher func (a *UserInternalAPI) PerformPushRulesPut( ctx context.Context, - req *api.PerformPushRulesPutRequest, - _ *struct{}, + userID string, + ruleSets *pushrules.AccountRuleSets, ) error { - bs, err := json.Marshal(&req.RuleSets) + bs, err := json.Marshal(ruleSets) if err != nil { return err } userReq := api.InputAccountDataRequest{ - UserID: req.UserID, + UserID: userID, DataType: pushRulesAccountDataType, AccountData: json.RawMessage(bs), } var userRes api.InputAccountDataResponse // empty - if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { - return err - } - return nil + return a.InputAccountData(ctx, &userReq, &userRes) } -func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + return nil, fmt.Errorf("failed to split user ID %q for push rules", userID) } - pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) - if err != nil { - return fmt.Errorf("failed to query push rules: %w", err) - } - res.RuleSets = pushRules - return nil + return a.DB.QueryPushRules(ctx, localpart, domain) } -func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) { +func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) { return a.DB.SetAvatarURL(ctx, localpart, serverName, avatarURL) } @@ -955,7 +958,7 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } } -func (a *UserInternalAPI) SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) { +func (a *UserInternalAPI) SetDisplayName(ctx context.Context, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) { return a.DB.SetDisplayName(ctx, localpart, serverName, displayName) } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4ffb126a7a..125b315853 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -21,7 +21,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/userapi/api" @@ -29,41 +31,50 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, tokenString string) error + UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) +} + type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) } type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) - GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error) - GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) - SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error + CreateAccount(ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context, serverName spec.ServerName) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string, serverName spec.ServerName) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) + SetPassword(ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string) error } type AccountData interface { - SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error - GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) + SaveAccountData(ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) - QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error) + GetAccountDataByType(ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string) (data json.RawMessage, err error) + QueryPushRules(ctx context.Context, localpart string, serverName spec.ServerName) (*pushrules.AccountRuleSets, error) } type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error) + GetDeviceByID(ctx context.Context, localpart string, serverName spec.ServerName, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked @@ -71,12 +82,12 @@ type Device interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error - RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error + CreateDevice(ctx context.Context, localpart string, serverName spec.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart string, serverName spec.ServerName, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error + RemoveDevices(ctx context.Context, localpart string, serverName spec.ServerName, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart string, serverName spec.ServerName, exceptDeviceID string) (devices []api.Device, err error) } type KeyBackup interface { @@ -108,26 +119,26 @@ type OpenID interface { } type Pusher interface { - UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error - GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) - RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error + UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName spec.ServerName) error + GetPushers(ctx context.Context, localpart string, serverName spec.ServerName) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName spec.ServerName) error RemovePushers(ctx context.Context, appid, pushkey string) error } type ThreePID interface { - SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName spec.ServerName, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) } type Notification interface { - InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) - GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) - GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) - GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) + InsertNotification(ctx context.Context, localpart string, serverName spec.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) DeleteOldNotifications(ctx context.Context) error } @@ -143,6 +154,7 @@ type UserDatabase interface { Pusher Statistics ThreePID + RegistrationTokens } type KeyChangeDatabase interface { @@ -199,7 +211,7 @@ type KeyDatabase interface { // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error @@ -209,7 +221,7 @@ type KeyDatabase interface { CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error DeleteStaleDeviceLists( ctx context.Context, @@ -219,8 +231,8 @@ type KeyDatabase interface { type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) - DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) - UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error + DailyRoomsMessages(ctx context.Context, serverName spec.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) + UpsertDailyRoomsMessages(ctx context.Context, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 057160374d..6ffda340eb 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -22,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const accountDataSchema = ` @@ -74,7 +74,7 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) @@ -90,7 +90,7 @@ func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) SelectAccountData( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, @@ -129,7 +129,7 @@ func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 9c46249a70..bb97545d03 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -20,13 +20,12 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -79,10 +78,10 @@ type accountsStatements struct { selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { +func NewPostgresAccountsTable(db *sql.DB, serverName spec.ServerName) (tables.AccountsTable, error) { s := &accountsStatements{ serverName: serverName, } @@ -127,7 +126,7 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // on success. func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -153,7 +152,7 @@ func (s *accountsStatements) InsertAccount( } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) @@ -161,21 +160,21 @@ func (s *accountsStatements) UpdatePassword( } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account @@ -197,7 +196,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index b6fe6d7210..138b629d71 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningKeysSchema = ` @@ -76,7 +76,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( r = types.CrossSigningKeyMap{} for rows.Next() { var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes + var keyData spec.Base64Bytes if err := rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } @@ -90,7 +90,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go index b0117145c6..61a3811841 100644 --- a/userapi/storage/postgres/cross_signing_sigs_table.go +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningSigsSchema = ` @@ -96,12 +97,12 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( for rows.Next() { var userID string var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes + var signature spec.Base64Bytes if err := rows.Scan(&userID, &keyID, &signature); err != nil { return nil, err } if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + r[userID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } r[userID][keyID] = signature } @@ -112,7 +113,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, + signature spec.Base64Bytes, ) error { if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go index 375f775bec..b3fe2d43c0 100644 --- a/userapi/storage/postgres/deltas/2022110411000000_server_names.go +++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var serverNamesTables = []string{ @@ -43,7 +43,7 @@ var serverNamesDropIndex = []string{ // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go index 04a47fa7ba..f83859dfae 100644 --- a/userapi/storage/postgres/deltas/2022110411000001_server_names.go +++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // I know what you're thinking: you're wondering "why doesn't this use $1 @@ -14,7 +14,7 @@ import ( // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "UPDATE %s SET server_name = %s WHERE server_name = '';", diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 88f8839c58..0335f82665 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const devicesSchema = ` @@ -112,10 +112,10 @@ type devicesStatements struct { deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { +func NewPostgresDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.DevicesTable, error) { s := &devicesStatements{ serverName: serverName, } @@ -151,7 +151,7 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName // Returns the device on success. func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -176,7 +176,7 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64, ) (*api.Device, error) { @@ -186,7 +186,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart, serverName) @@ -197,7 +197,7 @@ func (s *devicesStatements) DeleteDevice( // Returns an error if the execution failed. func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -209,7 +209,7 @@ func (s *devicesStatements) DeleteDevices( // given user localpart. func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -219,7 +219,7 @@ func (s *devicesStatements) DeleteDevicesByLocalpart( func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -232,7 +232,7 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { @@ -246,7 +246,7 @@ func (s *devicesStatements) SelectDeviceByToken( // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -279,7 +279,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { @@ -300,7 +300,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -342,7 +342,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index a1cff2c46a..f589362586 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -20,13 +20,13 @@ import ( "encoding/json" "time" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" ) type notificationsStatements struct { @@ -112,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -128,7 +128,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -142,7 +142,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -155,7 +155,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { @@ -168,7 +168,7 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local for rows.Next() { var id int64 var roomID string - var ts gomatrixserverlib.Timestamp + var ts spec.Timestamp var read bool var jsonStr string err = rows.Scan( @@ -198,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (count int64, err error) { err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, err error) { err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 68d87f007f..345877d11c 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -34,10 +34,10 @@ const selectOpenIDTokenSQL = "" + type openIDTokenStatements struct { insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { +func NewPostgresOpenIDTable(db *sql.DB, serverName spec.ServerName) (tables.OpenIDTable, error) { s := &openIDTokenStatements{ serverName: serverName, } @@ -56,7 +56,7 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, serverName gomatrixserverlib.ServerName, + token, localpart string, serverName spec.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) @@ -72,7 +72,7 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index df4e0db63e..e404c32f29 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const profilesSchema = ` @@ -92,7 +92,7 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (err error) { _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return @@ -100,7 +100,7 @@ func (s *profilesStatements) InsertProfile( func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( @@ -114,7 +114,7 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ @@ -130,7 +130,7 @@ func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index e255406b9b..d943a18392 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -95,7 +95,7 @@ type pushersStatements struct { func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err @@ -103,7 +103,7 @@ func (s *pushersStatements) InsertPusher( func (s *pushersStatements) SelectPushers( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName) @@ -144,7 +144,7 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( ctx context.Context, txn *sql.Tx, appid, pushkey, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go new file mode 100644 index 0000000000..3c3e3fdd93 --- /dev/null +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -0,0 +1,222 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + var rows *sql.Rows + var err error + if returnAll { + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, rows.Err() +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go index c823b58c6a..e2086dc996 100644 --- a/userapi/storage/postgres/stale_device_lists.go +++ b/userapi/storage/postgres/stale_device_lists.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -81,11 +82,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, spec.AsTimestamp(time.Now())) return err } -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { // we only query for 1 domain or all domains so optimise for those use cases if len(domains) == 0 { rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index f62467fa49..a7949e4bab 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -20,7 +20,7 @@ import ( "time" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" @@ -191,7 +191,7 @@ ON CONFLICT (localpart, device_id, timestamp) DO NOTHING const queryDBEngineVersion = "SHOW server_version;" type statsStatements struct { - serverName gomatrixserverlib.ServerName + serverName spec.ServerName lastUpdate time.Time countUsersLastSeenAfterStmt *sql.Stmt countR30UsersStmt *sql.Stmt @@ -204,7 +204,7 @@ type statsStatements struct { selectDailyMessagesStmt *sql.Stmt } -func NewPostgresStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { +func NewPostgresStatsTable(db *sql.DB, serverName spec.ServerName) (tables.StatsTable, error) { s := &statsStatements{ serverName: serverName, lastUpdate: time.Now(), @@ -280,7 +280,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) int64(api.AccountTypeAppService), }, api.AccountTypeGuest, - gomatrixserverlib.AsTimestamp(registeredAfter), + spec.AsTimestamp(registeredAfter), ) if err != nil { return nil, err @@ -304,7 +304,7 @@ func (s *statsStatements) dailyUsers(ctx context.Context, txn *sql.Tx) (result i stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -1) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -313,7 +313,7 @@ func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -30) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -330,7 +330,7 @@ func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string diff := time.Hour * 24 * 30 rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), diff.Milliseconds(), ) if err != nil { @@ -367,8 +367,8 @@ func (s *statsStatements) r30UsersV2(ctx context.Context, txn *sql.Tx) (map[stri tomorrow := time.Now().Add(time.Hour * 24) rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(sixtyDaysAgo), - gomatrixserverlib.AsTimestamp(tomorrow), + spec.AsTimestamp(sixtyDaysAgo), + spec.AsTimestamp(tomorrow), diff.Milliseconds(), ) if err != nil { @@ -464,9 +464,9 @@ func (s *statsStatements) UpdateUserDailyVisits( startTime = startTime.AddDate(0, 0, -1) } _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(startTime), - gomatrixserverlib.AsTimestamp(lastUpdate), - gomatrixserverlib.AsTimestamp(time.Now()), + spec.AsTimestamp(startTime), + spec.AsTimestamp(lastUpdate), + spec.AsTimestamp(time.Now()), ) if err == nil { s.lastUpdate = time.Now() @@ -476,13 +476,13 @@ func (s *statsStatements) UpdateUserDailyVisits( func (s *statsStatements) UpsertDailyStats( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, stats types.MessageStats, + serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64, ) error { stmt := sqlutil.TxStmt(txn, s.upsertMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(timestamp), + spec.AsTimestamp(timestamp), serverName, stats.Messages, stats.SentMessages, stats.MessagesE2EE, stats.SentMessagesE2EE, activeRooms, activeE2EERooms, @@ -492,12 +492,12 @@ func (s *statsStatements) UpsertDailyStats( func (s *statsStatements) DailyRoomsMessages( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { stmt := sqlutil.TxStmt(txn, s.selectDailyMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) - err = stmt.QueryRowContext(ctx, serverName, gomatrixserverlib.AsTimestamp(timestamp)). + err = stmt.QueryRowContext(ctx, serverName, spec.AsTimestamp(timestamp)). Scan(&msgStats.Messages, &msgStats.SentMessages, &msgStats.MessagesE2EE, &msgStats.SentMessagesE2EE, &activeRooms, &activeE2EERooms) if err != nil && err != sql.ErrNoRows { return msgStats, 0, 0, err diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 3769c16054..2481fe67b1 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -20,19 +20,18 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" + "github.com/matrix-org/gomatrixserverlib/spec" // Import the postgres database driver. _ "github.com/lib/pq" ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err @@ -58,6 +57,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * return nil, err } + registationTokensTable, err := NewPostgresRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err) + } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) @@ -130,6 +133,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + RegistrationTokens: registationTokensTable, Stats: statsTable, ServerName: serverName, DB: db, diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index f41c431223..15b42a0a6a 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -77,7 +77,7 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, serverName gomatrixserverlib.ServerName, err error) { +) (localpart string, serverName spec.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { @@ -88,7 +88,7 @@ func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { @@ -113,7 +113,7 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 742cdccfb7..da9572969b 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -28,8 +28,10 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -42,6 +44,7 @@ import ( type Database struct { DB *sql.DB Writer sqlutil.Writer + RegistrationTokens tables.RegistrationTokensTable Accounts tables.AccountsTable Profiles tables.ProfileTable AccountDatas tables.AccountDataTable @@ -55,7 +58,7 @@ type Database struct { Pushers tables.PusherTable Stats tables.StatsTable LoginTokenLifetime time.Duration - ServerName gomatrixserverlib.ServerName + ServerName spec.ServerName BcryptCost int OpenIDTokenLifetimeMS int64 } @@ -77,10 +80,46 @@ const ( loginTokenByteLength = 32 ) +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token) +} + +func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken) + return err + }) + return +} + +func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid) +} + +func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) +} + +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) + return err + }) + return +} + +func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string, ) (*api.Account, error) { hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName) @@ -100,7 +139,7 @@ func (d *Database) GetAccountByPassword( // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*authtypes.Profile, error) { return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } @@ -109,7 +148,7 @@ func (d *Database) GetProfileByLocalpart( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, avatarURL string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -123,7 +162,7 @@ func (d *Database) SetAvatarURL( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, displayName string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -135,7 +174,7 @@ func (d *Database) SetDisplayName( // SetPassword sets the account password to the given hash. func (d *Database) SetPassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword string, ) error { hash, err := d.hashPassword(plaintextPassword) @@ -151,7 +190,7 @@ func (d *Database) SetPassword( // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -176,7 +215,7 @@ func (d *Database) CreateAccount( // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error @@ -208,7 +247,7 @@ func (d *Database) createAccount( func (d *Database) QueryPushRules( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*pushrules.AccountRuleSets, error) { data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules") if err != nil { @@ -247,7 +286,7 @@ func (d *Database) QueryPushRules( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -258,7 +297,7 @@ func (d *Database) SaveAccountData( // GetAccountData returns account data related to a given localpart // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ( +func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName spec.ServerName) ( global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error, @@ -271,7 +310,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string, serverN // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string, ) (data json.RawMessage, err error) { return d.AccountDatas.SelectAccountDataByType( @@ -281,7 +320,7 @@ func (d *Database) GetAccountDataByType( // GetNewNumericLocalpart generates and returns a new unused numeric localpart func (d *Database) GetNewNumericLocalpart( - ctx context.Context, serverName gomatrixserverlib.ServerName, + ctx context.Context, serverName spec.ServerName, ) (int64, error) { return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName) } @@ -301,7 +340,7 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") // Returns an error if there was a problem talking to the database. func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -339,7 +378,7 @@ func (d *Database) RemoveThreePIDAssociation( // Returns an error if there was a problem talking to the database. func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, -) (localpart string, serverName gomatrixserverlib.ServerName, err error) { +) (localpart string, serverName spec.ServerName, err error) { return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -349,7 +388,7 @@ func (d *Database) GetLocalpartForThreePID( // Returns an error if there was an issue talking to the database. func (d *Database) GetThreePIDsForLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (threepids []authtypes.ThreePID, err error) { return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) } @@ -357,7 +396,7 @@ func (d *Database) GetThreePIDsForLocalpart( // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName spec.ServerName) (bool, error) { _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) if err == sql.ErrNoRows { return true, nil @@ -368,7 +407,7 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // GetAccountByLocalpart returns the account associated with the given localpart. // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName, ) (*api.Account, error) { // try to get the account with lowercase localpart (majority) acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName) @@ -386,7 +425,7 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) { +func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Accounts.DeactivateAccount(ctx, localpart, serverName) }) @@ -571,7 +610,7 @@ func (d *Database) GetDeviceByAccessToken( // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, ) (*api.Device, error) { return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID) @@ -580,7 +619,7 @@ func (d *Database) GetDeviceByID( // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ([]api.Device, error) { return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "") } @@ -596,7 +635,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil && *deviceID != "" { @@ -674,7 +713,7 @@ func generateDeviceID() (string, error) { // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, displayName *string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -688,7 +727,7 @@ func (d *Database) UpdateDevice( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveDevices( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, devices []string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -704,7 +743,7 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) (devices []api.Device, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -721,7 +760,7 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a last seen timestamp and the ip address. -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent) }) @@ -771,13 +810,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName spec.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos) return err @@ -785,7 +824,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b) return err @@ -793,15 +832,15 @@ func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, s return } -func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { +func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter) } -func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) { +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (int64, error) { return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter) } -func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) { +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) { return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID) } @@ -813,7 +852,7 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error { func (d *Database) UpsertPusher( ctx context.Context, p api.Pusher, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { data, err := json.Marshal(p.Data) if err != nil { @@ -839,7 +878,7 @@ func (d *Database) UpsertPusher( // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) ([]api.Pusher, error) { return d.Pushers.SelectPushers(ctx, nil, localpart, serverName) } @@ -848,7 +887,7 @@ func (d *Database) GetPushers( // Invoked when `append` is true and `kind` is null in // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set func (d *Database) RemovePusher( - ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, appid, pushkey, localpart string, serverName spec.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName) @@ -875,14 +914,14 @@ func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, * return d.Stats.UserStatistics(ctx, nil) } -func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error { +func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Stats.UpsertDailyStats(ctx, txn, serverName, stats, activeRooms, activeE2EERooms) }) } func (d *Database) DailyRoomsMessages( - ctx context.Context, serverName gomatrixserverlib.ServerName, + ctx context.Context, serverName spec.ServerName, ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } @@ -995,7 +1034,7 @@ func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64 // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. -func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } @@ -1037,7 +1076,7 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string result := fclient.CrossSigningKey{ UserID: userID, Usage: []fclient.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{ keyID: key, }, } @@ -1050,10 +1089,10 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string continue } if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + result.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } for sigKeyID, sigBytes := range forSigUserID { result.Signatures[sigUserID][sigKeyID] = sigBytes @@ -1091,7 +1130,7 @@ func (d *KeyDatabase) StoreCrossSigningSigsForTarget( ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, + signature spec.Base64Bytes, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index 2fbdc57327..3a6367c450 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const accountDataSchema = ` @@ -76,7 +76,7 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage, ) error { _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content) @@ -85,7 +85,7 @@ func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) SelectAccountData( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, @@ -123,7 +123,7 @@ func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index f4ebe2158e..d01915a764 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -19,13 +19,12 @@ import ( "database/sql" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -79,10 +78,10 @@ type accountsStatements struct { selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { +func NewSQLiteAccountsTable(db *sql.DB, serverName spec.ServerName) (tables.AccountsTable, error) { s := &accountsStatements{ db: db, serverName: serverName, @@ -122,7 +121,7 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -148,7 +147,7 @@ func (s *accountsStatements) InsertAccount( } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) @@ -156,21 +155,21 @@ func (s *accountsStatements) UpdatePassword( } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + ctx context.Context, localpart string, serverName spec.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account @@ -192,7 +191,7 @@ func (s *accountsStatements) SelectAccountByLocalpart( } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, + ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index e1c45c4116..5c2ce70397 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -23,8 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningKeysSchema = ` @@ -75,7 +75,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( r = types.CrossSigningKeyMap{} for rows.Next() { var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes + var keyData spec.Base64Bytes if err := rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } @@ -89,7 +89,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( } func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, + ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes, ) error { keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] if !ok { diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go index 2be00c9c11..6572641158 100644 --- a/userapi/storage/sqlite3/cross_signing_sigs_table.go +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) var crossSigningSigsSchema = ` @@ -94,12 +95,12 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( for rows.Next() { var userID string var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes + var signature spec.Base64Bytes if err := rows.Scan(&userID, &keyID, &signature); err != nil { return nil, err } if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + r[userID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{} } r[userID][keyID] = signature } @@ -110,7 +111,7 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, + signature spec.Base64Bytes, ) error { if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go index c11ea68445..76f39a9086 100644 --- a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -42,7 +42,7 @@ var serverNamesDropIndex = []string{ // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;", diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go index 04a47fa7ba..f83859dfae 100644 --- a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/lib/pq" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // I know what you're thinking: you're wondering "why doesn't this use $1 @@ -14,7 +14,7 @@ import ( // PostgreSQL doesn't expect the table name to be specified as a substituted // argument in that way so it results in a syntax error in the query. -func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName spec.ServerName) error { for _, table := range serverNamesTables { q := fmt.Sprintf( "UPDATE %s SET server_name = %s WHERE server_name = '';", diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 65e17527df..23e8231168 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -25,9 +25,9 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/gomatrixserverlib" ) const devicesSchema = ` @@ -97,10 +97,10 @@ type devicesStatements struct { updateDeviceLastSeenStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { +func NewSQLiteDevicesTable(db *sql.DB, serverName spec.ServerName) (tables.DevicesTable, error) { s := &devicesStatements{ db: db, serverName: serverName, @@ -137,7 +137,7 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns the device on success. func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -167,7 +167,7 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64, ) (*api.Device, error) { @@ -193,7 +193,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id string, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) _, err := stmt.ExecContext(ctx, id, localpart, serverName) @@ -202,7 +202,7 @@ func (s *devicesStatements) DeleteDevice( func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, devices []string, ) error { orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) @@ -224,7 +224,7 @@ func (s *devicesStatements) DeleteDevices( func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -234,7 +234,7 @@ func (s *devicesStatements) DeleteDevicesByLocalpart( func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -247,7 +247,7 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName stmt := s.selectDeviceByTokenStmt err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { @@ -261,7 +261,7 @@ func (s *devicesStatements) SelectDeviceByToken( // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -287,7 +287,7 @@ func (s *devicesStatements) SelectDeviceByID( func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -343,7 +343,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { @@ -362,7 +362,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index 049fcbd06a..fbe916153d 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -20,13 +20,13 @@ import ( "encoding/json" "time" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib/spec" ) type notificationsStatements struct { @@ -112,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -128,7 +128,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -142,7 +142,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err @@ -155,7 +155,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { @@ -168,7 +168,7 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local for rows.Next() { var id int64 var roomID string - var ts gomatrixserverlib.Timestamp + var ts spec.Timestamp var read bool var jsonStr string err = rows.Scan( @@ -198,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (count int64, err error) { err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, err error) { err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index f064297411..def0074d2b 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -35,10 +35,10 @@ type openIDTokenStatements struct { db *sql.DB insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + serverName spec.ServerName } -func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { +func NewSQLiteOpenIDTable(db *sql.DB, serverName spec.ServerName) (tables.OpenIDTable, error) { s := &openIDTokenStatements{ db: db, serverName: serverName, @@ -58,7 +58,7 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, serverName gomatrixserverlib.ServerName, + token, localpart string, serverName spec.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) @@ -74,7 +74,7 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes var localpart string - var serverName gomatrixserverlib.ServerName + var serverName spec.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 867026d7af..a20d7e848f 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) const profilesSchema = ` @@ -88,7 +88,7 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err @@ -96,7 +96,7 @@ func (s *profilesStatements) InsertProfile( func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( @@ -110,7 +110,7 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ @@ -132,7 +132,7 @@ func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index c9d451dc52..e09f9c78fc 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -25,7 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -98,7 +98,7 @@ type pushersStatements struct { func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err @@ -106,7 +106,7 @@ func (s *pushersStatements) InsertPusher( func (s *pushersStatements) SelectPushers( ctx context.Context, txn *sql.Tx, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName) @@ -147,7 +147,7 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( ctx context.Context, txn *sql.Tx, appid, pushkey, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) error { _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go new file mode 100644 index 0000000000..8979547317 --- /dev/null +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -0,0 +1,222 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + var rows *sql.Rows + var err error + if returnAll { + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, rows.Err() +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go index f078fc99fe..5302899f4b 100644 --- a/userapi/storage/sqlite3/stale_device_lists.go +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -83,11 +84,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, spec.AsTimestamp(time.Now())) return err } -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) { // we only query for 1 domain or all domains so optimise for those use cases if len(domains) == 0 { rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 72b3ba49d1..71d80d4d41 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -20,7 +20,7 @@ import ( "strings" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" @@ -195,7 +195,7 @@ ON CONFLICT (localpart, device_id, timestamp) DO NOTHING const queryDBEngineVersion = "select sqlite_version();" type statsStatements struct { - serverName gomatrixserverlib.ServerName + serverName spec.ServerName db *sql.DB lastUpdate time.Time countUsersLastSeenAfterStmt *sql.Stmt @@ -209,7 +209,7 @@ type statsStatements struct { selectDailyMessagesStmt *sql.Stmt } -func NewSQLiteStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { +func NewSQLiteStatsTable(db *sql.DB, serverName spec.ServerName) (tables.StatsTable, error) { s := &statsStatements{ serverName: serverName, lastUpdate: time.Now(), @@ -298,8 +298,8 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) params[i] = v // i: 0 1 2 => ($1, $2, $3) params[i+1+len(nonGuests)] = v // i: 4 5 6 => ($5, $6, $7) } - params[3] = api.AccountTypeGuest // $4 - params[7] = gomatrixserverlib.AsTimestamp(registeredAfter) // $8 + params[3] = api.AccountTypeGuest // $4 + params[7] = spec.AsTimestamp(registeredAfter) // $8 rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -324,7 +324,7 @@ func (s *statsStatements) dailyUsers(ctx context.Context, txn *sql.Tx) (result i stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -1) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -333,7 +333,7 @@ func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result stmt := sqlutil.TxStmt(txn, s.countUsersLastSeenAfterStmt) lastSeenAfter := time.Now().AddDate(0, 0, -30) err = stmt.QueryRowContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), ).Scan(&result) return } @@ -348,8 +348,8 @@ func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string diff := time.Hour * 24 * 30 rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(lastSeenAfter), - gomatrixserverlib.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), + spec.AsTimestamp(lastSeenAfter), diff.Milliseconds(), ) if err != nil { @@ -386,8 +386,8 @@ func (s *statsStatements) r30UsersV2(ctx context.Context, txn *sql.Tx) (map[stri tomorrow := time.Now().Add(time.Hour * 24) rows, err := stmt.QueryContext(ctx, - gomatrixserverlib.AsTimestamp(sixtyDaysAgo), - gomatrixserverlib.AsTimestamp(tomorrow), + spec.AsTimestamp(sixtyDaysAgo), + spec.AsTimestamp(tomorrow), diff.Milliseconds(), ) if err != nil { @@ -482,9 +482,9 @@ func (s *statsStatements) UpdateUserDailyVisits( startTime = startTime.AddDate(0, 0, -1) } _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(startTime), - gomatrixserverlib.AsTimestamp(lastUpdate), - gomatrixserverlib.AsTimestamp(time.Now()), + spec.AsTimestamp(startTime), + spec.AsTimestamp(lastUpdate), + spec.AsTimestamp(time.Now()), ) if err == nil { s.lastUpdate = time.Now() @@ -494,13 +494,13 @@ func (s *statsStatements) UpdateUserDailyVisits( func (s *statsStatements) UpsertDailyStats( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, stats types.MessageStats, + serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64, ) error { stmt := sqlutil.TxStmt(txn, s.upsertMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) _, err := stmt.ExecContext(ctx, - gomatrixserverlib.AsTimestamp(timestamp), + spec.AsTimestamp(timestamp), serverName, stats.Messages, stats.SentMessages, stats.MessagesE2EE, stats.SentMessagesE2EE, activeRooms, activeE2EERooms, @@ -510,12 +510,12 @@ func (s *statsStatements) UpsertDailyStats( func (s *statsStatements) DailyRoomsMessages( ctx context.Context, txn *sql.Tx, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, ) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { stmt := sqlutil.TxStmt(txn, s.selectDailyMessagesStmt) timestamp := time.Now().Truncate(time.Hour * 24) - err = stmt.QueryRowContext(ctx, serverName, gomatrixserverlib.AsTimestamp(timestamp)). + err = stmt.QueryRowContext(ctx, serverName, spec.AsTimestamp(timestamp)). Scan(&msgStats.Messages, &msgStats.SentMessages, &msgStats.MessagesE2EE, &msgStats.SentMessagesE2EE, &activeRooms, &activeE2EERooms) if err != nil && err != sql.ErrNoRows { return msgStats, 0, 0, err diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 3742eebada..48f5c842bf 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -20,17 +20,16 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // NewUserDatabase creates a new accounts and profiles database -func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err @@ -51,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti if err = m.Up(ctx); err != nil { return nil, err } - + registationTokensTable, err := NewSQLiteRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err) + } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) @@ -131,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti LoginTokenLifetime: loginTokenLifetime, BcryptCost: bcryptCost, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + RegistrationTokens: registationTokensTable, }, nil } diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 2db7d5887e..a83f804238 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,7 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -81,7 +81,7 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, serverName gomatrixserverlib.ServerName, err error) { +) (localpart string, serverName spec.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { @@ -92,7 +92,7 @@ func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { @@ -117,7 +117,7 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, - localpart string, serverName gomatrixserverlib.ServerName, + localpart string, serverName spec.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 6981765f99..39231b2240 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -23,7 +23,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres" @@ -36,7 +36,7 @@ func NewUserDatabase( ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index fac1a29d85..87913e0f82 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" @@ -533,11 +534,11 @@ func Test_Notification(t *testing.T) { {}, }, Event: synctypes.ClientEvent{ - Content: gomatrixserverlib.RawJSON("{}"), + Content: spec.RawJSON("{}"), }, Read: false, RoomID: roomID, - TS: gomatrixserverlib.AsTimestamp(ts), + TS: spec.AsTimestamp(ts), } err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 19e5f23c63..cbadd98e94 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -29,7 +29,7 @@ func NewUserDatabase( ctx context.Context, conMan sqlutil.Connections, dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, + serverName spec.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 2d1339282c..3a0be73e4a 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -23,38 +23,49 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokensTable interface { + RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error + UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) +} + type AccountDataTable interface { - InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error - SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) - SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart string, serverName spec.ServerName, roomID, dataType string) (data json.RawMessage, err error) } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) - UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error) - DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) - SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error) - SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) - SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart string, serverName spec.ServerName, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) + SelectPasswordHash(ctx context.Context, localpart string, serverName spec.ServerName) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (id int64, err error) } type DevicesTable interface { - InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) - InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) - DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error - DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error - DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error - UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName spec.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName spec.ServerName) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID string, displayName *string) error SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) - SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) - SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error) + SelectDeviceByID(ctx context.Context, localpart string, serverName spec.ServerName, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface { @@ -81,47 +92,47 @@ type LoginTokenTable interface { } type OpenIDTable interface { - InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error) + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName spec.ServerName, expiresAtMS int64) (err error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error - SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } type ThreePIDTable interface { - SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) - SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) - InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName spec.ServerName, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName spec.ServerName) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error - SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) - DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName spec.ServerName) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName spec.ServerName) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error } type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) - Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) - SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error) - SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) } type StatsTable interface { UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) - DailyRoomsMessages(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) + DailyRoomsMessages(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error - UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error + UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } type NotificationFilter uint32 @@ -176,17 +187,17 @@ type KeyChanges interface { type StaleDeviceLists interface { InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error } type CrossSigningKeys interface { SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType fclient.CrossSigningKeyPurpose, keyData spec.Base64Bytes) error } type CrossSigningSigs interface { SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature spec.Base64Bytes) error DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error } diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go index b9bdafdaab..09924eb084 100644 --- a/userapi/storage/tables/stale_device_lists_test.go +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -6,7 +6,7 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/postgres" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -57,7 +57,7 @@ func TestStaleDeviceLists(t *testing.T) { // Query one server wantStaleUsers := []string{alice.ID, bob.ID} - gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []spec.ServerName{"test"}) if err != nil { t.Fatalf("failed to query stale device lists: %s", err) } @@ -67,7 +67,7 @@ func TestStaleDeviceLists(t *testing.T) { // Query all servers wantStaleUsers = []string{alice.ID, bob.ID, charlie} - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []spec.ServerName{}) if err != nil { t.Fatalf("failed to query stale device lists: %s", err) } @@ -82,7 +82,7 @@ func TestStaleDeviceLists(t *testing.T) { } // Verify we don't get anything back after deleting - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []spec.ServerName{"test"}) if err != nil { t.Fatalf("failed to query stale device lists: %s", err) } diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 969bc5303e..61fe026636 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -79,7 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, - serverName gomatrixserverlib.ServerName, // nolint:unparam + serverName spec.ServerName, // nolint:unparam accType api.AccountType, userAgent string, ) { @@ -108,7 +108,7 @@ func mustUpdateDeviceLastSeen( timestamp time.Time, ) { t.Helper() - _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_devices SET last_seen_ts = $1 WHERE localpart = $2", spec.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } @@ -121,7 +121,7 @@ func mustUserUpdateRegistered( localpart string, timestamp time.Time, ) { - _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", gomatrixserverlib.AsTimestamp(timestamp), localpart) + _, err := db.ExecContext(ctx, "UPDATE userapi_accounts SET created_ts = $1 WHERE localpart = $2", spec.AsTimestamp(timestamp), localpart) if err != nil { t.Fatalf("unable to update device last seen") } diff --git a/userapi/types/storage.go b/userapi/types/storage.go index a910f7f101..2c918847d3 100644 --- a/userapi/types/storage.go +++ b/userapi/types/storage.go @@ -19,6 +19,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" ) const ( @@ -45,7 +46,7 @@ var KeyTypeIntToPurpose = map[int16]fclient.CrossSigningKeyPurpose{ } // Map of purpose -> public key -type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes +type CrossSigningKeyMap map[fclient.CrossSigningKeyPurpose]spec.Base64Bytes // Map of user ID -> key ID -> signature -type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes +type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 9d068ca3b1..45762a7d87 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" @@ -41,7 +42,7 @@ import ( ) const ( - serverName = gomatrixserverlib.ServerName("example.com") + serverName = spec.ServerName("example.com") ) type apiTestOpts struct { @@ -74,7 +75,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, pub cfg, ctx, close := testrig.CreateConfig(t, dbType) sName := serverName if opts.serverName != "" { - sName = gomatrixserverlib.ServerName(opts.serverName) + sName = spec.ServerName(opts.serverName) } cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) diff --git a/userapi/util/devices.go b/userapi/util/devices.go index 31617d8c10..117da08ea9 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -7,7 +7,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName spec.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index 08d1371d61..45d37525c1 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" log "github.com/sirupsen/logrus" ) @@ -17,7 +17,7 @@ import ( // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName spec.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index d6cbad7db8..3017069bc5 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "golang.org/x/crypto/bcrypt" @@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Prepare pusher with our test server URL - if err := db.UpsertPusher(ctx, api.Pusher{ + if err = db.UpsertPusher(ctx, api.Pusher{ Kind: api.HTTPKind, AppID: appID, PushKey: pushKey, @@ -99,8 +100,13 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event + sender, err := spec.NewUserID(alice.ID, true) + if err != nil { + t.Error(err) + } + sk := "" if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.HeaderedToClientEvent(dummyEvent, synctypes.FormatAll), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk), }); err != nil { t.Error(err) } diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 21035e0451..4bf9a5d886 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -24,18 +24,18 @@ import ( "syscall" "time" - "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib/spec" ) type phoneHomeStats struct { prevData timestampToRUUsage stats map[string]interface{} - serverName gomatrixserverlib.ServerName + serverName spec.ServerName startTime time.Time cfg *config.Dendrite db storage.Statistics