From 9e0a40a3ed3368310cfaa310e144f50a1a489740 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 4 Oct 2023 18:08:42 +0100 Subject: [PATCH] add more useful must functions --- must/must.go | 50 +++++++++++++++++++++-------- tests/direct_messaging_test.go | 12 ++----- tests/federation_acl_test.go | 23 +++---------- tests/federation_event_auth_test.go | 13 +++----- tests/federation_keys_test.go | 35 +++++++------------- tests/federation_redaction_test.go | 10 ++---- 6 files changed, 61 insertions(+), 82 deletions(-) diff --git a/must/must.go b/must/must.go index 014af7ec..0ea78f6b 100644 --- a/must/must.go +++ b/must/must.go @@ -16,6 +16,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/complement/client" "github.com/matrix-org/complement/match" ) @@ -144,15 +145,7 @@ func MatchFederationRequest(t *testing.T, fedReq *fclient.FederationRequest, mat func MatchGJSON(t *testing.T, jsonResult gjson.Result, matchers ...match.JSON) { t.Helper() - MatchJSON(t, jsonResult.Raw, matchers...) -} - -// EXPERIMENTAL -// MatchJSON performs JSON assertions on a raw JSON string. -func MatchJSON(t *testing.T, json string, matchers ...match.JSON) { - t.Helper() - - MatchJSONBytes(t, []byte(json), matchers...) + MatchJSONBytes(t, []byte(jsonResult.Raw), matchers...) } // EXPERIMENTAL @@ -166,12 +159,13 @@ func MatchJSONBytes(t *testing.T, rawJson []byte, matchers ...match.JSON) { for _, jm := range matchers { if err := jm(rawJson); err != nil { - t.Fatalf("MatchJSONBytes %s", err) + t.Fatalf("MatchJSONBytes %s with input = %v", err, string(rawJson)) } } } // Equal ensures that got==want else logs an error. +// The 'msg' is displayed with the error to provide extra context. func Equal[V comparable](t *testing.T, got, want V, msg string) { t.Helper() if got != want { @@ -180,6 +174,7 @@ func Equal[V comparable](t *testing.T, got, want V, msg string) { } // NotEqualStr ensures that got!=want else logs an error. +// The 'msg' is displayed with the error to provide extra context. func NotEqual[V comparable](t *testing.T, got, want V, msg string) { t.Helper() if got == want { @@ -242,20 +237,47 @@ func HaveInAnyOrder[V constraints.Ordered](t *testing.T, gots []V, wants []V) { } // EXPERIMENTAL -// ContainsSubset checks that every item in smaller is in larger, failing the test if at least 1 item isn't. Ignores extra elements +// ContainSubset checks that every item in smaller is in larger, failing the test if at least 1 item isn't. Ignores extra elements // in larger. Ignores ordering. -func ContainsSubset[V comparable](t *testing.T, larger []V, smaller []V) { +func ContainSubset[V comparable](t *testing.T, larger []V, smaller []V) { t.Helper() if len(larger) < len(smaller) { - t.Fatalf("ContainsSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) + t.Fatalf("ContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) } for i, item := range smaller { if !slices.Contains(larger, item) { - t.Fatalf("ContainsSubset: element not found in larger set: smaller[%d] (%v)", i, item) + t.Fatalf("ContainSubset: element not found in larger set: smaller[%d] (%v) larger=%v", i, item, larger) } } } +// EXPERIMENTAL +// NotContainSubset checks that every item in smaller is NOT in larger, failing the test if at least 1 item is. Ignores extra elements +// in larger. Ignores ordering. +func NotContainSubset[V comparable](t *testing.T, larger []V, smaller []V) { + t.Helper() + if len(larger) < len(smaller) { + t.Fatalf("NotContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller)) + } + for i, item := range smaller { + if slices.Contains(larger, item) { + t.Fatalf("NotContainSubset: element found in larger set: smaller[%d] (%v)", i, item) + } + } +} + +// EXPERIMENTAL +// GetTimelineEventIDs returns the timeline event IDs in the sync response for the given room ID. If the room is missing +// this returns a 0 element slice. +func GetTimelineEventIDs(t *testing.T, topLevelSyncJSON gjson.Result, roomID string) []string { + timeline := topLevelSyncJSON.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(roomID))).Array() + eventIDs := make([]string, len(timeline)) + for i := range timeline { + eventIDs[i] = timeline[i].Get("event_id").Str + } + return eventIDs +} + // EXPERIMENTAL // CheckOffAll checks that a list contains exactly the given items, in any order. // diff --git a/tests/direct_messaging_test.go b/tests/direct_messaging_test.go index 8a3a7f0d..2693accb 100644 --- a/tests/direct_messaging_test.go +++ b/tests/direct_messaging_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/matrix-org/complement/client" "github.com/matrix-org/complement/b" + "github.com/matrix-org/complement/client" "github.com/matrix-org/complement/internal/federation" "github.com/matrix-org/complement/match" "github.com/matrix-org/complement/must" @@ -42,15 +42,7 @@ func TestWriteMDirectAccountData(t *testing.T) { if r.Get("type").Str != "m.direct" { return false } - content := r.Get("content") - rooms := content.Get(bob.UserID) - if !rooms.Exists() || !rooms.IsArray() { - t.Errorf("m.direct event missing rooms array for user %s", bob.UserID) - return false - } - if rooms.Array()[0].Str != roomID { - t.Errorf("m.direct room for %s mismatch: got %v want %v", bob.UserID, rooms.Str, roomID) - } + must.MatchGJSON(t, r, match.JSONKeyEqual("content."+client.GjsonEscape(bob.UserID), []string{roomID})) return true } t.Logf("%s: global account data set; syncing until it arrives", time.Now()) // synapse#13334 diff --git a/tests/federation_acl_test.go b/tests/federation_acl_test.go index ef22e685..74f265a3 100644 --- a/tests/federation_acl_test.go +++ b/tests/federation_acl_test.go @@ -1,7 +1,6 @@ package tests import ( - "fmt" "testing" "github.com/matrix-org/complement/b" @@ -120,26 +119,12 @@ func TestACLs(t *testing.T) { syncResp, _ := user.MustSync(t, client.SyncReq{}) // we don't expect eventID (blocked) to be in the sync response - events := syncResp.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(roomID))).Array() - for _, ev := range events { - if ev.Get("event_id").Str == eventID { - t.Fatalf("unexpected eventID from ACLed room: %s", eventID) - } - } + events := must.GetTimelineEventIDs(t, syncResp, roomID) + must.NotContainSubset(t, events, []string{eventID}) // also check that our sentinel event is present - var seenSentinelEvent bool - events = syncResp.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(sentinelRoom))).Array() - for _, ev := range events { - if ev.Get("event_id").Str == sentinelEventID { - seenSentinelEvent = true - break - } - } - - if !seenSentinelEvent { - t.Fatalf("expected to see sentinel event but didn't") - } + events = must.GetTimelineEventIDs(t, syncResp, sentinelRoom) + must.ContainSubset(t, events, []string{sentinelEventID}) // Validate the ACL event is actually in the rooms state res := user.Do(t, "GET", []string{"_matrix", "client", "v3", "rooms", roomID, "state", "m.room.server_acl"}) diff --git a/tests/federation_event_auth_test.go b/tests/federation_event_auth_test.go index 4b224dac..6404cadb 100644 --- a/tests/federation_event_auth_test.go +++ b/tests/federation_event_auth_test.go @@ -92,16 +92,11 @@ func TestEventAuth(t *testing.T) { t.Fatalf("got %d valid auth events (%d total), wanted %d.\n%s\nwant: %s", len(gotAuthEvents), len(eventAuthResp.AuthEvents), len(wantAuthEventIDs), msg, wantAuthEventIDs) } // make sure all the events match - wantIDs := map[string]bool{} - for _, id := range wantAuthEventIDs { - wantIDs[id] = true - } - for _, e := range gotAuthEvents { - delete(wantIDs, e.EventID()) - } - if len(wantIDs) > 0 { - t.Errorf("missing events %v", wantIDs) + gotIDs := make([]string, len(gotAuthEvents)) + for i := range gotIDs { + gotIDs[i] = gotAuthEvents[i].EventID() } + must.ContainSubset(t, gotIDs, wantAuthEventIDs) } t.Run("returns auth events for the requested event", func(t *testing.T) { diff --git a/tests/federation_keys_test.go b/tests/federation_keys_test.go index 1b50a049..b63d3474 100644 --- a/tests/federation_keys_test.go +++ b/tests/federation_keys_test.go @@ -56,13 +56,10 @@ func TestInboundFederationKeys(t *testing.T) { key := v.Get("key") - // Test key existence and string type - if !key.Exists() { - return fmt.Errorf("verify_keys: Key '%s' has no 'key' field", k.Str) - } - if key.Type != gjson.String { - return fmt.Errorf("verify_keys: Key '%s' has 'key' with unexpected type, expected String, got '%s'", k.Str, key.Type.String()) - } + must.MatchGJSON(t, v, + match.JSONKeyPresent("key"), + match.JSONKeyTypeEqual("key", gjson.String), + ) var keyBytes []byte keyBytes, err = base64.RawStdEncoding.DecodeString(key.Str) @@ -79,26 +76,18 @@ func TestInboundFederationKeys(t *testing.T) { return fmt.Errorf("old_verify_keys: Key '%s' has no 'ed25519:' prefix", k.Str) } - expiredTs := v.Get("expired_ts") + must.MatchGJSON(t, v, + match.JSONKeyPresent("expired_ts"), + match.JSONKeyTypeEqual("expired_ts", gjson.Number), + ) - // Test expired_ts existence and number type - if !expiredTs.Exists() { - return fmt.Errorf("old_verify_keys: Key '%s' has no 'expired_ts' field", k.Str) - } - if expiredTs.Type != gjson.Number { - return fmt.Errorf("old_verify_keys: Key '%s' has expired_ts with unexpected type, expected Number, got '%s'", k.Str, expiredTs.Type.String()) - } + must.MatchGJSON(t, v, + match.JSONKeyPresent("key"), + match.JSONKeyTypeEqual("key", gjson.String), + ) key := v.Get("key") - // Test key existence and string type - if !key.Exists() { - return fmt.Errorf("old_verify_keys: Key '%s' has no 'key' field", k.Str) - } - if key.Type != gjson.String { - return fmt.Errorf("old_verify_keys: Key '%s' has 'key' with unexpected type, expected String, got '%s'", k.Str, key.Type.String()) - } - var keyBytes []byte keyBytes, err = base64.RawStdEncoding.DecodeString(key.Str) if err != nil { diff --git a/tests/federation_redaction_test.go b/tests/federation_redaction_test.go index 5531e3f0..00a86c87 100644 --- a/tests/federation_redaction_test.go +++ b/tests/federation_redaction_test.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/complement/b" "github.com/matrix-org/complement/internal/federation" + "github.com/matrix-org/complement/must" "github.com/matrix-org/complement/runtime" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,9 +32,7 @@ func TestFederationRedactSendsWithoutEvent(t *testing.T) { func(ev gomatrixserverlib.PDU) { defer waiter.Finish() - if ev.Type() != wantEventType { - t.Errorf("Wrong event type, got %s want %s", ev.Type(), wantEventType) - } + must.Equal(t, ev.Type(), wantEventType, "wrong event type") }, nil, ), @@ -89,8 +88,5 @@ func TestFederationRedactSendsWithoutEvent(t *testing.T) { } // check that the event id of the redaction sent by alice is the same as the redaction event in the room - if res != lastEvent.EventID() { - t.Fatalf("Incorrent event id %s, wanted %s.", res, lastEvent.EventID()) - } - + must.Equal(t, lastEvent.EventID(), res, "incorrect event id") }