Skip to content

Commit

Permalink
Add TestLike interface for using the client without a testing.T
Browse files Browse the repository at this point in the history
Also fix bad merge conflict
  • Loading branch information
kegsay committed Oct 3, 2023
1 parent 7794642 commit b8ef7fe
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 42 deletions.
16 changes: 7 additions & 9 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import (
"crypto/sha1"
"encoding/hex"
"io"
"testing"

"github.com/matrix-org/complement/internal/must"
"github.com/tidwall/gjson"
)

Expand All @@ -20,7 +18,7 @@ func WithDeviceID(deviceID string) LoginOpt {
}

// LoginUser will log in to a homeserver and create a new device on an existing user.
func (c *CSAPI) LoginUser(t *testing.T, localpart, password string, opts ...LoginOpt) (userID, accessToken, deviceID string) {
func (c *CSAPI) LoginUser(t TestLike, localpart, password string, opts ...LoginOpt) (userID, accessToken, deviceID string) {
t.Helper()
reqBody := map[string]interface{}{
"identifier": map[string]interface{}{
Expand Down Expand Up @@ -50,7 +48,7 @@ func (c *CSAPI) LoginUser(t *testing.T, localpart, password string, opts ...Logi

// LoginUserWithRefreshToken will log in to a homeserver, with refresh token enabled,
// and create a new device on an existing user.
func (c *CSAPI) LoginUserWithRefreshToken(t *testing.T, localpart, password string) (userID, accessToken, refreshToken, deviceID string, expiresInMs int64) {
func (c *CSAPI) LoginUserWithRefreshToken(t TestLike, localpart, password string) (userID, accessToken, refreshToken, deviceID string, expiresInMs int64) {
t.Helper()
reqBody := map[string]interface{}{
"identifier": map[string]interface{}{
Expand All @@ -77,7 +75,7 @@ func (c *CSAPI) LoginUserWithRefreshToken(t *testing.T, localpart, password stri
}

// RefreshToken will consume a refresh token and return a new access token and refresh token.
func (c *CSAPI) ConsumeRefreshToken(t *testing.T, refreshToken string) (newAccessToken, newRefreshToken string, expiresInMs int64) {
func (c *CSAPI) ConsumeRefreshToken(t TestLike, refreshToken string) (newAccessToken, newRefreshToken string, expiresInMs int64) {
t.Helper()
reqBody := map[string]interface{}{
"refresh_token": refreshToken,
Expand All @@ -97,7 +95,7 @@ func (c *CSAPI) ConsumeRefreshToken(t *testing.T, refreshToken string) (newAcces

// RegisterUser will register the user with given parameters and
// return user ID, access token and device ID. It fails the test on network error.
func (c *CSAPI) RegisterUser(t *testing.T, localpart, password string) (userID, accessToken, deviceID string) {
func (c *CSAPI) RegisterUser(t TestLike, localpart, password string) (userID, accessToken, deviceID string) {
t.Helper()
reqBody := map[string]interface{}{
"auth": map[string]string{
Expand All @@ -121,13 +119,13 @@ func (c *CSAPI) RegisterUser(t *testing.T, localpart, password string) (userID,

// RegisterSharedSecret registers a new account with a shared secret via HMAC
// See https://github.com/matrix-org/synapse/blob/e550ab17adc8dd3c48daf7fedcd09418a73f524b/synapse/_scripts/register_new_matrix_user.py#L40
func (c *CSAPI) RegisterSharedSecret(t *testing.T, user, pass string, isAdmin bool) (userID, accessToken, deviceID string) {
func (c *CSAPI) RegisterSharedSecret(t TestLike, user, pass string, isAdmin bool) (userID, accessToken, deviceID string) {
resp := c.Do(t, "GET", []string{"_synapse", "admin", "v1", "register"})
if resp.StatusCode != 200 {
t.Skipf("Homeserver image does not support shared secret registration, /_synapse/admin/v1/register returned HTTP %d", resp.StatusCode)
return
}
body := must.ParseJSON(t, resp.Body)
body := ParseJSON(t, resp)
nonce := gjson.GetBytes(body, "nonce")
if !nonce.Exists() {
t.Fatalf("Malformed shared secret GET response: %s", string(body))
Expand All @@ -153,7 +151,7 @@ func (c *CSAPI) RegisterSharedSecret(t *testing.T, user, pass string, isAdmin bo
"admin": isAdmin,
}
resp = c.MustDo(t, "POST", []string{"_synapse", "admin", "v1", "register"}, WithJSONBody(t, reqBody))
body = must.ParseJSON(t, resp.Body)
body = ParseJSON(t, resp)
userID = gjson.GetBytes(body, "user_id").Str
accessToken = gjson.GetBytes(body, "access_token").Str
deviceID = gjson.GetBytes(body, "device_id").Str
Expand Down
69 changes: 40 additions & 29 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/matrix-org/gomatrixserverlib"
Expand All @@ -25,6 +24,18 @@ const (
SharedSecret = "complement"
)

// TestLike is an interface that testing.T satisfies. All client functions accept a TestLike interface,
// with the intention of a `testing.T` being passed into them. However, the client may be used in non-test
// scenarios e.g benchmarks, which can then use the same client by just implementing this interface.
type TestLike interface {
Helper()
Logf(msg string, args ...interface{})
Skipf(msg string, args ...interface{})
Error(args ...interface{})
Errorf(msg string, args ...interface{})
Fatalf(msg string, args ...interface{})
}

type CtxKey string

const (
Expand Down Expand Up @@ -55,7 +66,7 @@ type CSAPI struct {
}

// UploadContent uploads the provided content with an optional file name. Fails the test on error. Returns the MXC URI.
func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, contentType string) string {
func (c *CSAPI) UploadContent(t TestLike, fileBody []byte, fileName string, contentType string) string {
t.Helper()
query := url.Values{}
if fileName != "" {
Expand All @@ -70,7 +81,7 @@ func (c *CSAPI) UploadContent(t *testing.T, fileBody []byte, fileName string, co
}

// DownloadContent downloads media from the server, returning the raw bytes and the Content-Type. Fails the test on error.
func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
func (c *CSAPI) DownloadContent(t TestLike, mxcUri string) ([]byte, string) {
t.Helper()
origin, mediaId := SplitMxc(mxcUri)
res := c.MustDo(t, "GET", []string{"_matrix", "media", "v3", "download", origin, mediaId})
Expand All @@ -83,15 +94,15 @@ func (c *CSAPI) DownloadContent(t *testing.T, mxcUri string) ([]byte, string) {
}

// CreateRoom creates a room with an optional HTTP request body. Fails the test on error. Returns the room ID.
func (c *CSAPI) CreateRoom(t *testing.T, creationContent interface{}) string {
func (c *CSAPI) CreateRoom(t TestLike, creationContent interface{}) string {
t.Helper()
res := c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "createRoom"}, WithJSONBody(t, creationContent))
body := ParseJSON(t, res)
return GetJSONFieldStr(t, body, "room_id")
}

// JoinRoom joins the room ID or alias given, else fails the test. Returns the room ID.
func (c *CSAPI) JoinRoom(t *testing.T, roomIDOrAlias string, serverNames []string) string {
func (c *CSAPI) JoinRoom(t TestLike, roomIDOrAlias string, serverNames []string) string {
t.Helper()
// construct URL query parameters
query := make(url.Values, len(serverNames))
Expand All @@ -113,15 +124,15 @@ func (c *CSAPI) JoinRoom(t *testing.T, roomIDOrAlias string, serverNames []strin
}

// LeaveRoom leaves the room ID, else fails the test.
func (c *CSAPI) LeaveRoom(t *testing.T, roomID string) {
func (c *CSAPI) LeaveRoom(t TestLike, roomID string) {
t.Helper()
// leave the room
body := map[string]interface{}{}
c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "rooms", roomID, "leave"}, WithJSONBody(t, body))
}

// InviteRoom invites userID to the room ID, else fails the test.
func (c *CSAPI) InviteRoom(t *testing.T, roomID string, userID string) {
func (c *CSAPI) InviteRoom(t TestLike, roomID string, userID string) {
t.Helper()
// Invite the user to the room
body := map[string]interface{}{
Expand All @@ -130,19 +141,19 @@ func (c *CSAPI) InviteRoom(t *testing.T, roomID string, userID string) {
c.MustDo(t, "POST", []string{"_matrix", "client", "v3", "rooms", roomID, "invite"}, WithJSONBody(t, body))
}

func (c *CSAPI) GetGlobalAccountData(t *testing.T, eventType string) *http.Response {
func (c *CSAPI) GetGlobalAccountData(t TestLike, eventType string) *http.Response {
return c.MustDo(t, "GET", []string{"_matrix", "client", "v3", "user", c.UserID, "account_data", eventType})
}

func (c *CSAPI) SetGlobalAccountData(t *testing.T, eventType string, content map[string]interface{}) *http.Response {
func (c *CSAPI) SetGlobalAccountData(t TestLike, eventType string, content map[string]interface{}) *http.Response {
return c.MustDo(t, "PUT", []string{"_matrix", "client", "v3", "user", c.UserID, "account_data", eventType}, WithJSONBody(t, content))
}

func (c *CSAPI) GetRoomAccountData(t *testing.T, roomID string, eventType string) *http.Response {
func (c *CSAPI) GetRoomAccountData(t TestLike, roomID string, eventType string) *http.Response {
return c.MustDo(t, "GET", []string{"_matrix", "client", "v3", "user", c.UserID, "rooms", roomID, "account_data", eventType})
}

func (c *CSAPI) SetRoomAccountData(t *testing.T, roomID string, eventType string, content map[string]interface{}) *http.Response {
func (c *CSAPI) SetRoomAccountData(t TestLike, roomID string, eventType string, content map[string]interface{}) *http.Response {
return c.MustDo(t, "PUT", []string{"_matrix", "client", "v3", "user", c.UserID, "rooms", roomID, "account_data", eventType}, WithJSONBody(t, content))
}

Expand All @@ -159,7 +170,7 @@ func (c *CSAPI) SetRoomAccountData(t *testing.T, roomID string, eventType string
// }
//
// Push rules are returned in the same order received from the homeserver.
func (c *CSAPI) GetAllPushRules(t *testing.T) gjson.Result {
func (c *CSAPI) GetAllPushRules(t TestLike) gjson.Result {
t.Helper()

// We have to supply an empty string to the end of this path in order to generate a trailing slash.
Expand All @@ -184,7 +195,7 @@ func (c *CSAPI) GetAllPushRules(t *testing.T) gjson.Result {
// map[string]interface{}{"set_tweak": "highlight"},
// }),
// )
func (c *CSAPI) GetPushRule(t *testing.T, scope string, kind string, ruleID string) gjson.Result {
func (c *CSAPI) GetPushRule(t TestLike, scope string, kind string, ruleID string) gjson.Result {
t.Helper()

res := c.MustDo(t, "GET", []string{"_matrix", "client", "v3", "pushrules", scope, kind, ruleID})
Expand All @@ -203,7 +214,7 @@ func (c *CSAPI) GetPushRule(t *testing.T, scope string, kind string, ruleID stri
// c.SetPushRule(t, "global", "underride", "com.example.rule2", map[string]interface{}{
// "actions": []string{"dont_notify"},
// }, nil, "com.example.rule1")
func (c *CSAPI) SetPushRule(t *testing.T, scope string, kind string, ruleID string, body map[string]interface{}, before string, after string) *http.Response {
func (c *CSAPI) SetPushRule(t TestLike, scope string, kind string, ruleID string, body map[string]interface{}, before string, after string) *http.Response {
t.Helper()

// If the `before` or `after` arguments have been provided, construct same-named query parameters
Expand All @@ -220,7 +231,7 @@ func (c *CSAPI) SetPushRule(t *testing.T, scope string, kind string, ruleID stri

// SendEventUnsynced sends `e` into the room.
// Returns the event ID of the sent event.
func (c *CSAPI) SendEventUnsynced(t *testing.T, roomID string, e b.Event) string {
func (c *CSAPI) SendEventUnsynced(t TestLike, roomID string, e b.Event) string {
t.Helper()
txnID := int(atomic.AddInt64(&c.txnID, 1))
return c.SendEventUnsyncedWithTxnID(t, roomID, e, strconv.Itoa(txnID))
Expand All @@ -229,7 +240,7 @@ func (c *CSAPI) SendEventUnsynced(t *testing.T, roomID string, e b.Event) string
// SendEventUnsyncedWithTxnID sends `e` into the room with a prescribed transaction ID.
// This is useful for writing tests that interrogate transaction semantics.
// Returns the event ID of the sent event.
func (c *CSAPI) SendEventUnsyncedWithTxnID(t *testing.T, roomID string, e b.Event, txnID string) string {
func (c *CSAPI) SendEventUnsyncedWithTxnID(t TestLike, roomID string, e b.Event, txnID string) string {
t.Helper()
paths := []string{"_matrix", "client", "v3", "rooms", roomID, "send", e.Type, txnID}
if e.StateKey != nil {
Expand All @@ -243,7 +254,7 @@ func (c *CSAPI) SendEventUnsyncedWithTxnID(t *testing.T, roomID string, e b.Even

// SendEventSynced sends `e` into the room and waits for its event ID to come down /sync.
// Returns the event ID of the sent event.
func (c *CSAPI) SendEventSynced(t *testing.T, roomID string, e b.Event) string {
func (c *CSAPI) SendEventSynced(t TestLike, roomID string, e b.Event) string {
t.Helper()
eventID := c.SendEventUnsynced(t, roomID, e)
t.Logf("SendEventSynced waiting for event ID %s", eventID)
Expand All @@ -254,7 +265,7 @@ func (c *CSAPI) SendEventSynced(t *testing.T, roomID string, e b.Event) string {
}

// SendRedaction sends a redaction request. Will fail if the returned HTTP request code is not 200
func (c *CSAPI) SendRedaction(t *testing.T, roomID string, e b.Event, eventID string) string {
func (c *CSAPI) SendRedaction(t TestLike, roomID string, e b.Event, eventID string) string {
t.Helper()
txnID := int(atomic.AddInt64(&c.txnID, 1))
paths := []string{"_matrix", "client", "v3", "rooms", roomID, "redact", eventID, strconv.Itoa(txnID)}
Expand All @@ -264,7 +275,7 @@ func (c *CSAPI) SendRedaction(t *testing.T, roomID string, e b.Event, eventID st
}

// GetCapbabilities queries the server's capabilities
func (c *CSAPI) GetCapabilities(t *testing.T) []byte {
func (c *CSAPI) GetCapabilities(t TestLike) []byte {
t.Helper()
res := c.MustDo(t, "GET", []string{"_matrix", "client", "v3", "capabilities"})
body, err := io.ReadAll(res.Body)
Expand All @@ -275,7 +286,7 @@ func (c *CSAPI) GetCapabilities(t *testing.T) []byte {
}

// GetDefaultRoomVersion returns the server's default room version
func (c *CSAPI) GetDefaultRoomVersion(t *testing.T) gomatrixserverlib.RoomVersion {
func (c *CSAPI) GetDefaultRoomVersion(t TestLike) gomatrixserverlib.RoomVersion {
t.Helper()
capabilities := c.GetCapabilities(t)
defaultVersion := gjson.GetBytes(capabilities, `capabilities.m\.room_versions.default`)
Expand Down Expand Up @@ -310,7 +321,7 @@ func WithContentType(cType string) RequestOpt {
}

// WithJSONBody sets the HTTP request body to the JSON serialised form of `obj`
func WithJSONBody(t *testing.T, obj interface{}) RequestOpt {
func WithJSONBody(t TestLike, obj interface{}) RequestOpt {
return func(req *http.Request) {
t.Helper()
b, err := json.Marshal(obj)
Expand Down Expand Up @@ -341,7 +352,7 @@ func WithRetryUntil(timeout time.Duration, untilFn func(res *http.Response) bool
}

// MustDo is the same as Do but fails the test if the returned HTTP response code is not 2xx.
func (c *CSAPI) MustDo(t *testing.T, method string, paths []string, opts ...RequestOpt) *http.Response {
func (c *CSAPI) MustDo(t TestLike, method string, paths []string, opts ...RequestOpt) *http.Response {
t.Helper()
res := c.Do(t, method, paths, opts...)
if res.StatusCode < 200 || res.StatusCode >= 300 {
Expand All @@ -365,7 +376,7 @@ func (c *CSAPI) MustDo(t *testing.T, method string, paths []string, opts ...Requ
// match.JSONKeyEqual("errcode", "M_INVALID_USERNAME"),
// },
// })
func (c *CSAPI) Do(t *testing.T, method string, paths []string, opts ...RequestOpt) *http.Response {
func (c *CSAPI) Do(t TestLike, method string, paths []string, opts ...RequestOpt) *http.Response {
t.Helper()
for i := range paths {
paths[i] = url.PathEscape(paths[i])
Expand Down Expand Up @@ -450,7 +461,7 @@ func (c *CSAPI) Do(t *testing.T, method string, paths []string, opts ...RequestO
}

// NewLoggedClient returns an http.Client which logs requests/responses
func NewLoggedClient(t *testing.T, hsName string, cli *http.Client) *http.Client {
func NewLoggedClient(t TestLike, hsName string, cli *http.Client) *http.Client {
t.Helper()
if cli == nil {
cli = &http.Client{
Expand All @@ -466,7 +477,7 @@ func NewLoggedClient(t *testing.T, hsName string, cli *http.Client) *http.Client
}

type loggedRoundTripper struct {
t *testing.T
t TestLike
hsName string
wrap http.RoundTripper
}
Expand All @@ -483,7 +494,7 @@ func (t *loggedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
}

// GetJSONFieldStr extracts a value from a byte-encoded JSON body given a search key
func GetJSONFieldStr(t *testing.T, body []byte, wantKey string) string {
func GetJSONFieldStr(t TestLike, body []byte, wantKey string) string {
t.Helper()
res := gjson.GetBytes(body, wantKey)
if !res.Exists() {
Expand All @@ -495,7 +506,7 @@ func GetJSONFieldStr(t *testing.T, body []byte, wantKey string) string {
return res.Str
}

func GetJSONFieldStringArray(t *testing.T, body []byte, wantKey string) []string {
func GetJSONFieldStringArray(t TestLike, body []byte, wantKey string) []string {
t.Helper()

res := gjson.GetBytes(body, wantKey)
Expand All @@ -519,7 +530,7 @@ func GetJSONFieldStringArray(t *testing.T, body []byte, wantKey string) []string
}

// ParseJSON parses a JSON-encoded HTTP Response body into a byte slice
func ParseJSON(t *testing.T, res *http.Response) []byte {
func ParseJSON(t TestLike, res *http.Response) []byte {
t.Helper()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
Expand Down Expand Up @@ -569,7 +580,7 @@ func SplitMxc(mxcUri string) (string, string) {
//
// The messages parameter is nested as follows:
// user_id -> device_id -> content (map[string]interface{})
func (c *CSAPI) SendToDeviceMessages(t *testing.T, evType string, messages map[string]map[string]map[string]interface{}) {
func (c *CSAPI) SendToDeviceMessages(t TestLike, evType string, messages map[string]map[string]map[string]interface{}) {
t.Helper()
txnID := int(atomic.AddInt64(&c.txnID, 1))
c.MustDo(
Expand Down
5 changes: 2 additions & 3 deletions client/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"net/url"
"strings"
"testing"
"time"

"github.com/tidwall/gjson"
Expand Down Expand Up @@ -86,7 +85,7 @@ type SyncReq struct {
//
// Will time out after CSAPI.SyncUntilTimeout. Returns the `next_batch` token from the final
// response.
func (c *CSAPI) MustSyncUntil(t *testing.T, syncReq SyncReq, checks ...SyncCheckOpt) string {
func (c *CSAPI) MustSyncUntil(t TestLike, syncReq SyncReq, checks ...SyncCheckOpt) string {
t.Helper()
start := time.Now()
numResponsesReturned := 0
Expand Down Expand Up @@ -139,7 +138,7 @@ func (c *CSAPI) MustSyncUntil(t *testing.T, syncReq SyncReq, checks ...SyncCheck
//
// Fails the test if the /sync request does not return 200 OK.
// Returns the top-level parsed /sync response JSON as well as the next_batch token from the response.
func (c *CSAPI) MustSync(t *testing.T, syncReq SyncReq) (gjson.Result, string) {
func (c *CSAPI) MustSync(t TestLike, syncReq SyncReq) (gjson.Result, string) {
t.Helper()
query := url.Values{
"timeout": []string{"1000"},
Expand Down
2 changes: 1 addition & 1 deletion tests/csapi/apidoc_room_forget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestRoomForget(t *testing.T) {
},
})
alice.LeaveRoom(t, roomID)
alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "rooms", roomID, "forget"}, , client.WithJSONBody(t, struct{}{}))
alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "rooms", roomID, "forget"}, client.WithJSONBody(t, struct{}{}))
res := alice.Do(t, "GET", []string{"_matrix", "client", "v3", "rooms", roomID, "messages"})
must.MatchResponse(t, res, match.HTTPResponse{
StatusCode: http.StatusForbidden,
Expand Down

0 comments on commit b8ef7fe

Please sign in to comment.