Skip to content

Commit

Permalink
Add should package
Browse files Browse the repository at this point in the history
The `should` package is a more lenient form of the `must` package, in
that it does not fail the test if the assertion fails. Instead, these
functions return an error, which the test can then handle.

This is particularly useful in cases where a test needs to retry a
certain request. Previously, you had to hand-roll the checks on the
response because using `must.MatchResponse` would fail the test, even
though it might have been okay as the next retry may pass the matchers.

Fixes #546

Tests need to be revisited to see if this can be used in more places.
  • Loading branch information
kegsay committed Oct 5, 2023
1 parent 9e57f77 commit 607de92
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 159 deletions.
201 changes: 51 additions & 150 deletions must/must.go
Original file line number Diff line number Diff line change
@@ -1,169 +1,107 @@
// Package must contains assertions for tests
// Package must contains assertions for tests, which fail the test if the assertion fails.
package must

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"

"github.com/tidwall/gjson"
"golang.org/x/exp/slices"

"github.com/matrix-org/gomatrixserverlib/fclient"

"github.com/matrix-org/complement/client"
"github.com/matrix-org/complement/match"
"github.com/matrix-org/complement/should"
)

// NotError will ensure `err` is nil else terminate the test with `msg`.
func NotError(t *testing.T, msg string, err error) {
t.Helper()
if err != nil {
t.Fatalf("MustNotError: %s -> %s", msg, err)
t.Fatalf("must.NotError: %s -> %s", msg, err)
}
}

// EXPERIMENTAL
// ParseJSON will ensure that the HTTP request/response body is valid JSON, then return the body, else terminate the test.
func ParseJSON(t *testing.T, b io.ReadCloser) gjson.Result {
t.Helper()
body, err := io.ReadAll(b)
res, err := should.ParseJSON(b)
if err != nil {
t.Fatalf("MustParseJSON: reading body returned %s", err)
t.Fatalf(err.Error())
}
if !gjson.ValidBytes(body) {
t.Fatalf("MustParseJSON: not valid JSON")
}
return gjson.ParseBytes(body)
return res
}

// EXPERIMENTAL
// MatchRequest consumes the HTTP request and performs HTTP-level assertions on it. Returns the raw response body.
func MatchRequest(t *testing.T, req *http.Request, m match.HTTPRequest) []byte {
t.Helper()
body, err := io.ReadAll(req.Body)
res, err := should.MatchRequest(req, m)
if err != nil {
t.Fatalf("MatchRequest: Failed to read request body: %s", err)
}

contextStr := fmt.Sprintf("%s => %s", req.URL.String(), string(body))

if m.Headers != nil {
for name, val := range m.Headers {
if req.Header.Get(name) != val {
t.Fatalf("MatchRequest got %s: %s want %s - %s", name, req.Header.Get(name), val, contextStr)
}
}
}
if m.JSON != nil {
if !gjson.ValidBytes(body) {
t.Fatalf("MatchRequest request body is not valid JSON - %s", contextStr)
}
parsedBody := gjson.ParseBytes(body)
for _, jm := range m.JSON {
if err = jm(parsedBody); err != nil {
t.Fatalf("MatchRequest %s - %s", err, contextStr)
}
}
t.Fatalf(err.Error())
}
return body
return res
}

// EXPERIMENTAL
// MatchSuccess consumes the HTTP response and fails if the response is non-2xx.
func MatchSuccess(t *testing.T, res *http.Response) {
if res.StatusCode < 200 || res.StatusCode > 299 {
t.Fatalf("MatchSuccess got status %d instead of a success code", res.StatusCode)
t.Helper()
if err := should.MatchSuccess(res); err != nil {
t.Fatalf(err.Error())
}
}

// EXPERIMENTAL
// MatchFailure consumes the HTTP response and fails if the response is 2xx.
func MatchFailure(t *testing.T, res *http.Response) {
if res.StatusCode >= 200 && res.StatusCode <= 299 {
t.Fatalf("MatchFailure got status %d instead of a failure code", res.StatusCode)
t.Helper()
if err := should.MatchFailure(res); err != nil {
t.Fatalf(err.Error())
}
}

// EXPERIMENTAL
// MatchResponse consumes the HTTP response and performs HTTP-level assertions on it. Returns the raw response body.
func MatchResponse(t *testing.T, res *http.Response, m match.HTTPResponse) []byte {
t.Helper()
body, err := io.ReadAll(res.Body)
body, err := should.MatchResponse(res, m)
if err != nil {
t.Fatalf("MatchResponse: Failed to read response body: %s", err)
}

contextStr := fmt.Sprintf("%s => %s", res.Request.URL.String(), string(body))

if m.StatusCode != 0 {
if res.StatusCode != m.StatusCode {
t.Fatalf("MatchResponse got status %d want %d - %s", res.StatusCode, m.StatusCode, contextStr)
}
}
if m.Headers != nil {
for name, val := range m.Headers {
if res.Header.Get(name) != val {
t.Fatalf("MatchResponse got %s: %s want %s - %s", name, res.Header.Get(name), val, contextStr)
}
}
}
if m.JSON != nil {
if !gjson.ValidBytes(body) {
t.Fatalf("MatchResponse response body is not valid JSON - %s", contextStr)
}
parsedBody := gjson.ParseBytes(body)
for _, jm := range m.JSON {
if err = jm(parsedBody); err != nil {
t.Fatalf("MatchResponse %s - %s", err, contextStr)
}
}
t.Fatalf(err.Error())
}
return body
}

// MatchFederationRequest performs JSON assertions on incoming federation requests.
func MatchFederationRequest(t *testing.T, fedReq *fclient.FederationRequest, matchers ...match.JSON) {
t.Helper()
content := fedReq.Content()
if !gjson.ValidBytes(content) {
t.Fatalf("MatchFederationRequest content is not valid JSON - %s", fedReq.RequestURI())
}

parsedContent := gjson.ParseBytes(content)
for _, jm := range matchers {
if err := jm(parsedContent); err != nil {
t.Fatalf("MatchFederationRequest %s - %s", err, fedReq.RequestURI())
}
err := should.MatchFederationRequest(fedReq)
if err != nil {
t.Fatalf(err.Error())
}
}

// EXPERIMENTAL
// MatchGJSON performs JSON assertions on a gjson.Result object.
func MatchGJSON(t *testing.T, jsonResult gjson.Result, matchers ...match.JSON) {
t.Helper()

MatchJSONBytes(t, []byte(jsonResult.Raw), matchers...)
err := should.MatchGJSON(jsonResult, matchers...)
if err != nil {
t.Fatalf(err.Error())
}
}

// EXPERIMENTAL
// MatchJSONBytes performs JSON assertions on a raw json byte slice.
func MatchJSONBytes(t *testing.T, rawJson []byte, matchers ...match.JSON) {
t.Helper()

if !gjson.ValidBytes(rawJson) {
t.Fatalf("MatchJSONBytes: rawJson is not valid JSON")
}

body := gjson.ParseBytes(rawJson)
for _, jm := range matchers {
if err := jm(body); err != nil {
t.Fatalf("MatchJSONBytes %s with input = %v", err, string(rawJson))
}
err := should.MatchJSONBytes(rawJson, matchers...)
if err != nil {
t.Fatalf(err.Error())
}
}

Expand Down Expand Up @@ -198,27 +136,20 @@ func StartWithStr(t *testing.T, got, wantPrefix, msg string) {
// The format of `wantKey` is specified at https://godoc.org/github.com/tidwall/gjson#Get
func GetJSONFieldStr(t *testing.T, body gjson.Result, wantKey string) string {
t.Helper()
res := body.Get(wantKey)
if res.Index == 0 {
t.Fatalf("JSONFieldStr: key '%s' missing from %s", wantKey, body.Raw)
}
if res.Str == "" {
t.Fatalf("JSONFieldStr: key '%s' is not a string, body: %s", wantKey, body.Raw)
str, err := should.GetJSONFieldStr(body, wantKey)
if err != nil {
t.Fatalf(err.Error())
}
return res.Str
return str
}

// EXPERIMENTAL
// HaveInOrder checks that the two slices match exactly, failing the test on mismatches or omissions.
func HaveInOrder[V comparable](t *testing.T, gots []V, wants []V) {
t.Helper()
if len(gots) != len(wants) {
t.Fatalf("HaveInOrder: length mismatch, got %v want %v", gots, wants)
}
for i := range gots {
if gots[i] != wants[i] {
t.Errorf("HaveInOrder: index %d got %v want %v", i, gots[i], wants[i])
}
err := should.HaveInOrder(gots, wants)
if err != nil {
t.Fatalf(err.Error())
}
}

Expand All @@ -227,13 +158,9 @@ func HaveInOrder[V comparable](t *testing.T, gots []V, wants []V) {
// in larger. Ignores ordering.
func ContainSubset[V comparable](t *testing.T, larger []V, smaller []V) {
t.Helper()
if len(larger) < len(smaller) {
t.Fatalf("ContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller))
}
for i, item := range smaller {
if !slices.Contains(larger, item) {
t.Fatalf("ContainSubset: element not found in larger set: smaller[%d] (%v) larger=%v", i, item, larger)
}
err := should.ContainSubset(larger, smaller)
if err != nil {
t.Fatalf(err.Error())
}
}

Expand All @@ -242,26 +169,10 @@ func ContainSubset[V comparable](t *testing.T, larger []V, smaller []V) {
// in larger. Ignores ordering.
func NotContainSubset[V comparable](t *testing.T, larger []V, smaller []V) {
t.Helper()
if len(larger) < len(smaller) {
t.Fatalf("NotContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller))
}
for i, item := range smaller {
if slices.Contains(larger, item) {
t.Fatalf("NotContainSubset: element found in larger set: smaller[%d] (%v)", i, item)
}
}
}

// EXPERIMENTAL
// GetTimelineEventIDs returns the timeline event IDs in the sync response for the given room ID. If the room is missing
// this returns a 0 element slice.
func GetTimelineEventIDs(t *testing.T, topLevelSyncJSON gjson.Result, roomID string) []string {
timeline := topLevelSyncJSON.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(roomID))).Array()
eventIDs := make([]string, len(timeline))
for i := range timeline {
eventIDs[i] = timeline[i].Get("event_id").Str
err := should.NotContainSubset(larger, smaller)
if err != nil {
t.Fatalf(err.Error())
}
return eventIDs
}

// EXPERIMENTAL
Expand All @@ -272,9 +183,9 @@ func GetTimelineEventIDs(t *testing.T, topLevelSyncJSON gjson.Result, roomID str
// Items are compared using match.JSONDeepEqual
func CheckOffAll(t *testing.T, items []interface{}, wantItems []interface{}) {
t.Helper()
remaining := CheckOffAllAllowUnwanted(t, items, wantItems)
if len(remaining) > 0 {
t.Errorf("CheckOffAll: unexpected items %v", remaining)
err := should.CheckOffAll(items, wantItems)
if err != nil {
t.Fatalf(err.Error())
}
}

Expand All @@ -286,10 +197,11 @@ func CheckOffAll(t *testing.T, items []interface{}, wantItems []interface{}) {
// Items are compared using match.JSONDeepEqual
func CheckOffAllAllowUnwanted(t *testing.T, items []interface{}, wantItems []interface{}) []interface{} {
t.Helper()
for _, wantItem := range wantItems {
items = CheckOff(t, items, wantItem)
result, err := should.CheckOffAllAllowUnwanted(items, wantItems)
if err != nil {
t.Fatalf(err.Error())
}
return items
return result
}

// EXPERIMENTAL
Expand All @@ -298,22 +210,11 @@ func CheckOffAllAllowUnwanted(t *testing.T, items []interface{}, wantItems []int
// compared using JSON deep equal.
func CheckOff(t *testing.T, items []interface{}, wantItem interface{}) []interface{} {
t.Helper()
// check off the item
want := -1
for i, w := range items {
wBytes, _ := json.Marshal(w)
if jsonDeepEqual(wBytes, wantItem) {
want = i
break
}
}
if want == -1 {
t.Errorf("CheckOff: item %s not present", wantItem)
return items
result, err := should.CheckOff(items, wantItem)
if err != nil {
t.Fatalf(err.Error())
}
// delete the wanted item
items = append(items[:want], items[want+1:]...)
return items
return result
}

func jsonDeepEqual(gotJson []byte, wantValue interface{}) bool {
Expand Down
Loading

0 comments on commit 607de92

Please sign in to comment.