From 1bb323418ca9b0b93153d1dfdfdda74461947b57 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 9 Oct 2023 10:34:43 +0000 Subject: [PATCH 1/2] Send `Content-Type: application/json` for 'instructions' (#650) --- internal/instruction/runner.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/instruction/runner.go b/internal/instruction/runner.go index 0e31f745..97394b96 100644 --- a/internal/instruction/runner.go +++ b/internal/instruction/runner.go @@ -276,6 +276,11 @@ func (r *Runner) next(instrs []instruction, hsURL string, i int) (*http.Request, return nil, nil, 0 } + if body != nil { + // all bodies, if set, are JSON encoded + req.Header["Content-Type"] = []string{"application/json"} + } + q := req.URL.Query() if instr.accessToken != "" { at, ok := r.lookup.Load(instr.accessToken) From 00b332647aa5f295f7fcacc669807485d0a2cf50 Mon Sep 17 00:00:00 2001 From: kegsay Date: Mon, 9 Oct 2023 14:01:06 +0100 Subject: [PATCH 2/2] Add should package (#663) The `should` package is a more lenient form of the `must` package, in that it does not fail the test if the assertion fails. Instead, these functions return an error, which the test can then handle. This is particularly useful in cases where a test needs to retry a certain request. Previously, you had to hand-roll the checks on the response because using `must.MatchResponse` would fail the test, even though it might have been okay as the next retry may pass the matchers. Fixes https://github.com/matrix-org/complement/issues/546 Tests need to be revisited to see if this can be used in more places. --- must/must.go | 201 +++++------------- should/should.go | 290 ++++++++++++++++++++++++++ tests/csapi/apidoc_room_alias_test.go | 14 +- tests/federation_acl_test.go | 5 +- 4 files changed, 351 insertions(+), 159 deletions(-) create mode 100644 should/should.go diff --git a/must/must.go b/must/must.go index 5ba0a84a..f54f3745 100644 --- a/must/must.go +++ b/must/must.go @@ -1,29 +1,27 @@ -// Package must contains assertions for tests +// Package must contains assertions for tests, which fail the test if the assertion fails. package must import ( "bytes" "encoding/json" - "fmt" "io" "net/http" "strings" "testing" "github.com/tidwall/gjson" - "golang.org/x/exp/slices" "github.com/matrix-org/gomatrixserverlib/fclient" - "github.com/matrix-org/complement/client" "github.com/matrix-org/complement/match" + "github.com/matrix-org/complement/should" ) // NotError will ensure `err` is nil else terminate the test with `msg`. func NotError(t *testing.T, msg string, err error) { t.Helper() if err != nil { - t.Fatalf("MustNotError: %s -> %s", msg, err) + t.Fatalf("must.NotError: %s -> %s", msg, err) } } @@ -31,61 +29,39 @@ func NotError(t *testing.T, msg string, err error) { // ParseJSON will ensure that the HTTP request/response body is valid JSON, then return the body, else terminate the test. func ParseJSON(t *testing.T, b io.ReadCloser) gjson.Result { t.Helper() - body, err := io.ReadAll(b) + res, err := should.ParseJSON(b) if err != nil { - t.Fatalf("MustParseJSON: reading body returned %s", err) + t.Fatalf(err.Error()) } - if !gjson.ValidBytes(body) { - t.Fatalf("MustParseJSON: not valid JSON") - } - return gjson.ParseBytes(body) + return res } // EXPERIMENTAL // MatchRequest consumes the HTTP request and performs HTTP-level assertions on it. Returns the raw response body. func MatchRequest(t *testing.T, req *http.Request, m match.HTTPRequest) []byte { t.Helper() - body, err := io.ReadAll(req.Body) + res, err := should.MatchRequest(req, m) if err != nil { - t.Fatalf("MatchRequest: Failed to read request body: %s", err) - } - - contextStr := fmt.Sprintf("%s => %s", req.URL.String(), string(body)) - - if m.Headers != nil { - for name, val := range m.Headers { - if req.Header.Get(name) != val { - t.Fatalf("MatchRequest got %s: %s want %s - %s", name, req.Header.Get(name), val, contextStr) - } - } - } - if m.JSON != nil { - if !gjson.ValidBytes(body) { - t.Fatalf("MatchRequest request body is not valid JSON - %s", contextStr) - } - parsedBody := gjson.ParseBytes(body) - for _, jm := range m.JSON { - if err = jm(parsedBody); err != nil { - t.Fatalf("MatchRequest %s - %s", err, contextStr) - } - } + t.Fatalf(err.Error()) } - return body + return res } // EXPERIMENTAL // MatchSuccess consumes the HTTP response and fails if the response is non-2xx. func MatchSuccess(t *testing.T, res *http.Response) { - if res.StatusCode < 200 || res.StatusCode > 299 { - t.Fatalf("MatchSuccess got status %d instead of a success code", res.StatusCode) + t.Helper() + if err := should.MatchSuccess(res); err != nil { + t.Fatalf(err.Error()) } } // EXPERIMENTAL // MatchFailure consumes the HTTP response and fails if the response is 2xx. func MatchFailure(t *testing.T, res *http.Response) { - if res.StatusCode >= 200 && res.StatusCode <= 299 { - t.Fatalf("MatchFailure got status %d instead of a failure code", res.StatusCode) + t.Helper() + if err := should.MatchFailure(res); err != nil { + t.Fatalf(err.Error()) } } @@ -93,35 +69,9 @@ func MatchFailure(t *testing.T, res *http.Response) { // MatchResponse consumes the HTTP response and performs HTTP-level assertions on it. Returns the raw response body. func MatchResponse(t *testing.T, res *http.Response, m match.HTTPResponse) []byte { t.Helper() - body, err := io.ReadAll(res.Body) + body, err := should.MatchResponse(res, m) if err != nil { - t.Fatalf("MatchResponse: Failed to read response body: %s", err) - } - - contextStr := fmt.Sprintf("%s => %s", res.Request.URL.String(), string(body)) - - if m.StatusCode != 0 { - if res.StatusCode != m.StatusCode { - t.Fatalf("MatchResponse got status %d want %d - %s", res.StatusCode, m.StatusCode, contextStr) - } - } - if m.Headers != nil { - for name, val := range m.Headers { - if res.Header.Get(name) != val { - t.Fatalf("MatchResponse got %s: %s want %s - %s", name, res.Header.Get(name), val, contextStr) - } - } - } - if m.JSON != nil { - if !gjson.ValidBytes(body) { - t.Fatalf("MatchResponse response body is not valid JSON - %s", contextStr) - } - parsedBody := gjson.ParseBytes(body) - for _, jm := range m.JSON { - if err = jm(parsedBody); err != nil { - t.Fatalf("MatchResponse %s - %s", err, contextStr) - } - } + t.Fatalf(err.Error()) } return body } @@ -129,16 +79,9 @@ func MatchResponse(t *testing.T, res *http.Response, m match.HTTPResponse) []byt // MatchFederationRequest performs JSON assertions on incoming federation requests. func MatchFederationRequest(t *testing.T, fedReq *fclient.FederationRequest, matchers ...match.JSON) { t.Helper() - content := fedReq.Content() - if !gjson.ValidBytes(content) { - t.Fatalf("MatchFederationRequest content is not valid JSON - %s", fedReq.RequestURI()) - } - - parsedContent := gjson.ParseBytes(content) - for _, jm := range matchers { - if err := jm(parsedContent); err != nil { - t.Fatalf("MatchFederationRequest %s - %s", err, fedReq.RequestURI()) - } + err := should.MatchFederationRequest(fedReq) + if err != nil { + t.Fatalf(err.Error()) } } @@ -146,24 +89,19 @@ func MatchFederationRequest(t *testing.T, fedReq *fclient.FederationRequest, mat // MatchGJSON performs JSON assertions on a gjson.Result object. func MatchGJSON(t *testing.T, jsonResult gjson.Result, matchers ...match.JSON) { t.Helper() - - MatchJSONBytes(t, []byte(jsonResult.Raw), matchers...) + err := should.MatchGJSON(jsonResult, matchers...) + if err != nil { + t.Fatalf(err.Error()) + } } // EXPERIMENTAL // MatchJSONBytes performs JSON assertions on a raw json byte slice. func MatchJSONBytes(t *testing.T, rawJson []byte, matchers ...match.JSON) { t.Helper() - - if !gjson.ValidBytes(rawJson) { - t.Fatalf("MatchJSONBytes: rawJson is not valid JSON") - } - - body := gjson.ParseBytes(rawJson) - for _, jm := range matchers { - if err := jm(body); err != nil { - t.Fatalf("MatchJSONBytes %s with input = %v", err, string(rawJson)) - } + err := should.MatchJSONBytes(rawJson, matchers...) + if err != nil { + t.Fatalf(err.Error()) } } @@ -198,27 +136,20 @@ func StartWithStr(t *testing.T, got, wantPrefix, msg string) { // The format of `wantKey` is specified at https://godoc.org/github.com/tidwall/gjson#Get func GetJSONFieldStr(t *testing.T, body gjson.Result, wantKey string) string { t.Helper() - res := body.Get(wantKey) - if res.Index == 0 { - t.Fatalf("JSONFieldStr: key '%s' missing from %s", wantKey, body.Raw) - } - if res.Str == "" { - t.Fatalf("JSONFieldStr: key '%s' is not a string, body: %s", wantKey, body.Raw) + str, err := should.GetJSONFieldStr(body, wantKey) + if err != nil { + t.Fatalf(err.Error()) } - return res.Str + return str } // EXPERIMENTAL // HaveInOrder checks that the two slices match exactly, failing the test on mismatches or omissions. func HaveInOrder[V comparable](t *testing.T, gots []V, wants []V) { t.Helper() - if len(gots) != len(wants) { - t.Fatalf("HaveInOrder: length mismatch, got %v want %v", gots, wants) - } - for i := range gots { - if gots[i] != wants[i] { - t.Errorf("HaveInOrder: index %d got %v want %v", i, gots[i], wants[i]) - } + err := should.HaveInOrder(gots, wants) + if err != nil { + t.Fatalf(err.Error()) } } @@ -227,13 +158,9 @@ func HaveInOrder[V comparable](t *testing.T, gots []V, wants []V) { // in larger. Ignores ordering. func ContainSubset[V comparable](t *testing.T, larger []V, smaller []V) { t.Helper() - if len(larger) < len(smaller) { - t.Fatalf("ContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) - } - for i, item := range smaller { - if !slices.Contains(larger, item) { - t.Fatalf("ContainSubset: element not found in larger set: smaller[%d] (%v) larger=%v", i, item, larger) - } + err := should.ContainSubset(larger, smaller) + if err != nil { + t.Fatalf(err.Error()) } } @@ -242,26 +169,10 @@ func ContainSubset[V comparable](t *testing.T, larger []V, smaller []V) { // in larger. Ignores ordering. func NotContainSubset[V comparable](t *testing.T, larger []V, smaller []V) { t.Helper() - if len(larger) < len(smaller) { - t.Fatalf("NotContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) - } - for i, item := range smaller { - if slices.Contains(larger, item) { - t.Fatalf("NotContainSubset: element found in larger set: smaller[%d] (%v)", i, item) - } - } -} - -// EXPERIMENTAL -// GetTimelineEventIDs returns the timeline event IDs in the sync response for the given room ID. If the room is missing -// this returns a 0 element slice. -func GetTimelineEventIDs(t *testing.T, topLevelSyncJSON gjson.Result, roomID string) []string { - timeline := topLevelSyncJSON.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(roomID))).Array() - eventIDs := make([]string, len(timeline)) - for i := range timeline { - eventIDs[i] = timeline[i].Get("event_id").Str + err := should.NotContainSubset(larger, smaller) + if err != nil { + t.Fatalf(err.Error()) } - return eventIDs } // EXPERIMENTAL @@ -272,9 +183,9 @@ func GetTimelineEventIDs(t *testing.T, topLevelSyncJSON gjson.Result, roomID str // Items are compared using match.JSONDeepEqual func CheckOffAll(t *testing.T, items []interface{}, wantItems []interface{}) { t.Helper() - remaining := CheckOffAllAllowUnwanted(t, items, wantItems) - if len(remaining) > 0 { - t.Errorf("CheckOffAll: unexpected items %v", remaining) + err := should.CheckOffAll(items, wantItems) + if err != nil { + t.Fatalf(err.Error()) } } @@ -286,10 +197,11 @@ func CheckOffAll(t *testing.T, items []interface{}, wantItems []interface{}) { // Items are compared using match.JSONDeepEqual func CheckOffAllAllowUnwanted(t *testing.T, items []interface{}, wantItems []interface{}) []interface{} { t.Helper() - for _, wantItem := range wantItems { - items = CheckOff(t, items, wantItem) + result, err := should.CheckOffAllAllowUnwanted(items, wantItems) + if err != nil { + t.Fatalf(err.Error()) } - return items + return result } // EXPERIMENTAL @@ -298,22 +210,11 @@ func CheckOffAllAllowUnwanted(t *testing.T, items []interface{}, wantItems []int // compared using JSON deep equal. func CheckOff(t *testing.T, items []interface{}, wantItem interface{}) []interface{} { t.Helper() - // check off the item - want := -1 - for i, w := range items { - wBytes, _ := json.Marshal(w) - if jsonDeepEqual(wBytes, wantItem) { - want = i - break - } - } - if want == -1 { - t.Errorf("CheckOff: item %s not present", wantItem) - return items + result, err := should.CheckOff(items, wantItem) + if err != nil { + t.Fatalf(err.Error()) } - // delete the wanted item - items = append(items[:want], items[want+1:]...) - return items + return result } func jsonDeepEqual(gotJson []byte, wantValue interface{}) bool { diff --git a/should/should.go b/should/should.go new file mode 100644 index 00000000..1aa8dfbf --- /dev/null +++ b/should/should.go @@ -0,0 +1,290 @@ +// Package should contains assertions for tests, which returns an error if the assertion fails. +package should + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/tidwall/gjson" + "golang.org/x/exp/slices" + + "github.com/matrix-org/gomatrixserverlib/fclient" + + "github.com/matrix-org/complement/client" + "github.com/matrix-org/complement/match" +) + +// EXPERIMENTAL +// ParseJSON will ensure that the HTTP request/response body is valid JSON, then return the body, else returns an error. +func ParseJSON(b io.ReadCloser) (res gjson.Result, err error) { + body, err := io.ReadAll(b) + if err != nil { + return res, fmt.Errorf("ParseJSON: reading body returned %s", err) + } + if !gjson.ValidBytes(body) { + return res, fmt.Errorf("ParseJSON: not valid JSON") + } + return gjson.ParseBytes(body), nil +} + +// EXPERIMENTAL +// MatchRequest consumes the HTTP request and performs HTTP-level assertions on it. Returns the raw response body. +func MatchRequest(req *http.Request, m match.HTTPRequest) ([]byte, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("MatchRequest: Failed to read request body: %s", err) + } + + contextStr := fmt.Sprintf("%s => %s", req.URL.String(), string(body)) + + if m.Headers != nil { + for name, val := range m.Headers { + if req.Header.Get(name) != val { + return nil, fmt.Errorf("MatchRequest got %s: %s want %s - %s", name, req.Header.Get(name), val, contextStr) + } + } + } + if m.JSON != nil { + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("MatchRequest request body is not valid JSON - %s", contextStr) + } + parsedBody := gjson.ParseBytes(body) + for _, jm := range m.JSON { + if err = jm(parsedBody); err != nil { + return nil, fmt.Errorf("MatchRequest %s - %s", err, contextStr) + } + } + } + return body, nil +} + +// EXPERIMENTAL +// MatchSuccess consumes the HTTP response and fails if the response is non-2xx. +func MatchSuccess(res *http.Response) error { + if res.StatusCode < 200 || res.StatusCode > 299 { + return fmt.Errorf("MatchSuccess got status %d instead of a success code", res.StatusCode) + } + return nil +} + +// EXPERIMENTAL +// MatchFailure consumes the HTTP response and fails if the response is 2xx. +func MatchFailure(res *http.Response) error { + if res.StatusCode >= 200 && res.StatusCode <= 299 { + return fmt.Errorf("MatchFailure got status %d instead of a failure code", res.StatusCode) + } + return nil +} + +// EXPERIMENTAL +// MatchResponse consumes the HTTP response and performs HTTP-level assertions on it. Returns the raw response body. +func MatchResponse(res *http.Response, m match.HTTPResponse) ([]byte, error) { + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("MatchResponse: Failed to read response body: %s", err) + } + + contextStr := fmt.Sprintf("%s => %s", res.Request.URL.String(), string(body)) + + if m.StatusCode != 0 { + if res.StatusCode != m.StatusCode { + return nil, fmt.Errorf("MatchResponse got status %d want %d - %s", res.StatusCode, m.StatusCode, contextStr) + } + } + if m.Headers != nil { + for name, val := range m.Headers { + if res.Header.Get(name) != val { + return nil, fmt.Errorf("MatchResponse got %s: %s want %s - %s", name, res.Header.Get(name), val, contextStr) + } + } + } + if m.JSON != nil { + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("MatchResponse response body is not valid JSON - %s", contextStr) + } + parsedBody := gjson.ParseBytes(body) + for _, jm := range m.JSON { + if err = jm(parsedBody); err != nil { + return nil, fmt.Errorf("MatchResponse %s - %s", err, contextStr) + } + } + } + return body, nil +} + +// MatchFederationRequest performs JSON assertions on incoming federation requests. +func MatchFederationRequest(fedReq *fclient.FederationRequest, matchers ...match.JSON) error { + content := fedReq.Content() + if !gjson.ValidBytes(content) { + return fmt.Errorf("MatchFederationRequest content is not valid JSON - %s", fedReq.RequestURI()) + } + + parsedContent := gjson.ParseBytes(content) + for _, jm := range matchers { + if err := jm(parsedContent); err != nil { + return fmt.Errorf("MatchFederationRequest %s - %s", err, fedReq.RequestURI()) + } + } + return nil +} + +// EXPERIMENTAL +// MatchGJSON performs JSON assertions on a gjson.Result object. +func MatchGJSON(jsonResult gjson.Result, matchers ...match.JSON) error { + return MatchJSONBytes([]byte(jsonResult.Raw), matchers...) +} + +// EXPERIMENTAL +// MatchJSONBytes performs JSON assertions on a raw json byte slice. +func MatchJSONBytes(rawJson []byte, matchers ...match.JSON) error { + if !gjson.ValidBytes(rawJson) { + return fmt.Errorf("MatchJSONBytes: rawJson is not valid JSON") + } + + body := gjson.ParseBytes(rawJson) + for _, jm := range matchers { + if err := jm(body); err != nil { + return fmt.Errorf("MatchJSONBytes %s with input = %v", err, string(rawJson)) + } + } + return nil +} + +// EXPERIMENTAL +// GetJSONFieldStr extracts the string value under `wantKey` or fails the test. +// The format of `wantKey` is specified at https://godoc.org/github.com/tidwall/gjson#Get +func GetJSONFieldStr(body gjson.Result, wantKey string) (string, error) { + res := body.Get(wantKey) + if res.Index == 0 { + return "", fmt.Errorf("JSONFieldStr: key '%s' missing from %s", wantKey, body.Raw) + } + if res.Str == "" { + return "", fmt.Errorf("JSONFieldStr: key '%s' is not a string, body: %s", wantKey, body.Raw) + } + return res.Str, nil +} + +// EXPERIMENTAL +// HaveInOrder checks that the two slices match exactly, failing the test on mismatches or omissions. +func HaveInOrder[V comparable](gots []V, wants []V) error { + if len(gots) != len(wants) { + return fmt.Errorf("HaveInOrder: length mismatch, got %v want %v", gots, wants) + } + for i := range gots { + if gots[i] != wants[i] { + return fmt.Errorf("HaveInOrder: index %d got %v want %v", i, gots[i], wants[i]) + } + } + return nil +} + +// EXPERIMENTAL +// ContainSubset checks that every item in smaller is in larger, failing the test if at least 1 item isn't. Ignores extra elements +// in larger. Ignores ordering. +func ContainSubset[V comparable](larger []V, smaller []V) error { + if len(larger) < len(smaller) { + return fmt.Errorf("ContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) + } + for i, item := range smaller { + if !slices.Contains(larger, item) { + return fmt.Errorf("ContainSubset: element not found in larger set: smaller[%d] (%v) larger=%v", i, item, larger) + } + } + return nil +} + +// EXPERIMENTAL +// NotContainSubset checks that every item in smaller is NOT in larger, failing the test if at least 1 item is. Ignores extra elements +// in larger. Ignores ordering. +func NotContainSubset[V comparable](larger []V, smaller []V) error { + if len(larger) < len(smaller) { + return fmt.Errorf("NotContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) + } + for i, item := range smaller { + if slices.Contains(larger, item) { + return fmt.Errorf("NotContainSubset: element found in larger set: smaller[%d] (%v)", i, item) + } + } + return nil +} + +// EXPERIMENTAL +// GetTimelineEventIDs returns the timeline event IDs in the sync response for the given room ID. If the room is missing +// this returns a 0 element slice. +func GetTimelineEventIDs(topLevelSyncJSON gjson.Result, roomID string) []string { + timeline := topLevelSyncJSON.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(roomID))).Array() + eventIDs := make([]string, len(timeline)) + for i := range timeline { + eventIDs[i] = timeline[i].Get("event_id").Str + } + return eventIDs +} + +// EXPERIMENTAL +// CheckOffAll checks that a list contains exactly the given items, in any order. +// +// if an item is not present, an error is returned. +// if an item not present in the want list is present, an error is returned. +// Items are compared using match.JSONDeepEqual +func CheckOffAll(items []interface{}, wantItems []interface{}) error { + remaining, err := CheckOffAllAllowUnwanted(items, wantItems) + if err != nil { + return err + } + if len(remaining) > 0 { + return fmt.Errorf("CheckOffAll: unexpected items %v", remaining) + } + return nil +} + +// EXPERIMENTAL +// CheckOffAllAllowUnwanted checks that a list contains all of the given items, in any order. +// The updated list with the matched items removed from it is returned. +// +// if an item is not present, an error is returned +// Items are compared using match.JSONDeepEqual +func CheckOffAllAllowUnwanted(items []interface{}, wantItems []interface{}) ([]interface{}, error) { + var err error + for _, wantItem := range wantItems { + items, err = CheckOff(items, wantItem) + if err != nil { + return nil, err + } + } + return items, nil +} + +// EXPERIMENTAL +// CheckOff an item from the list. If the item is not present an error is returned +// The updated list with the matched item removed from it is returned. Items are +// compared using JSON deep equal. +func CheckOff(items []interface{}, wantItem interface{}) ([]interface{}, error) { + // check off the item + want := -1 + for i, w := range items { + wBytes, _ := json.Marshal(w) + if jsonDeepEqual(wBytes, wantItem) { + want = i + break + } + } + if want == -1 { + return nil, fmt.Errorf("CheckOff: item %s not present", wantItem) + } + // delete the wanted item + items = append(items[:want], items[want+1:]...) + return items, nil +} + +func jsonDeepEqual(gotJson []byte, wantValue interface{}) bool { + // marshal what the test gave us + wantBytes, _ := json.Marshal(wantValue) + // re-marshal what the network gave us to acount for key ordering + var gotVal interface{} + _ = json.Unmarshal(gotJson, &gotVal) + gotBytes, _ := json.Marshal(gotVal) + return bytes.Equal(gotBytes, wantBytes) +} diff --git a/tests/csapi/apidoc_room_alias_test.go b/tests/csapi/apidoc_room_alias_test.go index dc134ee5..347b4d56 100644 --- a/tests/csapi/apidoc_room_alias_test.go +++ b/tests/csapi/apidoc_room_alias_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/complement/client" "github.com/matrix-org/complement/match" "github.com/matrix-org/complement/must" + "github.com/matrix-org/complement/should" ) func setRoomAliasResp(t *testing.T, c *client.CSAPI, roomID, roomAlias string) *http.Response { @@ -116,13 +117,12 @@ func TestRoomAlias(t *testing.T) { client.WithRetryUntil( 1*time.Second, func(res *http.Response) bool { - if res.StatusCode != 200 { - return false - } - eventResBody := client.ParseJSON(t, res) - parsedEventResBody := gjson.ParseBytes(eventResBody) - matcher := match.JSONKeyEqual("aliases", []interface{}{roomAlias}) - err := matcher(parsedEventResBody) + _, err := should.MatchResponse(res, match.HTTPResponse{ + StatusCode: 200, + JSON: []match.JSON{ + match.JSONKeyEqual("aliases", []interface{}{roomAlias}), + }, + }) if err != nil { t.Log(err) return false diff --git a/tests/federation_acl_test.go b/tests/federation_acl_test.go index 74f265a3..398285e0 100644 --- a/tests/federation_acl_test.go +++ b/tests/federation_acl_test.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/complement/match" "github.com/matrix-org/complement/must" "github.com/matrix-org/complement/runtime" + "github.com/matrix-org/complement/should" ) // Test for https://github.com/matrix-org/dendrite/issues/3004 @@ -119,11 +120,11 @@ func TestACLs(t *testing.T) { syncResp, _ := user.MustSync(t, client.SyncReq{}) // we don't expect eventID (blocked) to be in the sync response - events := must.GetTimelineEventIDs(t, syncResp, roomID) + events := should.GetTimelineEventIDs(syncResp, roomID) must.NotContainSubset(t, events, []string{eventID}) // also check that our sentinel event is present - events = must.GetTimelineEventIDs(t, syncResp, sentinelRoom) + events = should.GetTimelineEventIDs(syncResp, sentinelRoom) must.ContainSubset(t, events, []string{sentinelEventID}) // Validate the ACL event is actually in the rooms state