Skip to content

Commit

Permalink
Add GenerateOneTimeKeys, Prefix Unsafe_ to unsafe functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay committed Oct 3, 2023
1 parent b8ef7fe commit 81193ac
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 124 deletions.
67 changes: 59 additions & 8 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
"maunium.net/go/mautrix/crypto/olm"

"github.com/matrix-org/complement/internal/b"
)
Expand All @@ -36,10 +37,10 @@ type TestLike interface {
Fatalf(msg string, args ...interface{})
}

type CtxKey string
type ctxKey string

const (
CtxKeyWithRetryUntil CtxKey = "complement_retry_until" // contains *retryUntilParams
CtxKeyWithRetryUntil ctxKey = "complement_retry_until" // contains *retryUntilParams
)

type retryUntilParams struct {
Expand Down Expand Up @@ -229,18 +230,20 @@ func (c *CSAPI) SetPushRule(t TestLike, scope string, kind string, ruleID string
return c.MustDo(t, "PUT", []string{"_matrix", "client", "v3", "pushrules", scope, kind, ruleID}, WithJSONBody(t, body), WithQueries(queryParams))
}

// SendEventUnsynced sends `e` into the room.
// Unsafe_SendEventUnsynced sends `e` into the room. This function is UNSAFE as it does not wait
// for the event to be fully processed. This can cause flakey tests. Prefer `SendEventSynced`.
// Returns the event ID of the sent event.
func (c *CSAPI) SendEventUnsynced(t TestLike, roomID string, e b.Event) string {
func (c *CSAPI) Unsafe_SendEventUnsynced(t TestLike, roomID string, e b.Event) string {
t.Helper()
txnID := int(atomic.AddInt64(&c.txnID, 1))
return c.SendEventUnsyncedWithTxnID(t, roomID, e, strconv.Itoa(txnID))
return c.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, e, strconv.Itoa(txnID))
}

// SendEventUnsyncedWithTxnID sends `e` into the room with a prescribed transaction ID.
// This is useful for writing tests that interrogate transaction semantics.
// This is useful for writing tests that interrogate transaction semantics. This function is UNSAFE
// as it does not wait for the event to be fully processed. This can cause flakey tests. Prefer `SendEventSynced`.
// Returns the event ID of the sent event.
func (c *CSAPI) SendEventUnsyncedWithTxnID(t TestLike, roomID string, e b.Event, txnID string) string {
func (c *CSAPI) Unsafe_SendEventUnsyncedWithTxnID(t TestLike, roomID string, e b.Event, txnID string) string {
t.Helper()
paths := []string{"_matrix", "client", "v3", "rooms", roomID, "send", e.Type, txnID}
if e.StateKey != nil {
Expand All @@ -256,7 +259,7 @@ func (c *CSAPI) SendEventUnsyncedWithTxnID(t TestLike, roomID string, e b.Event,
// Returns the event ID of the sent event.
func (c *CSAPI) SendEventSynced(t TestLike, roomID string, e b.Event) string {
t.Helper()
eventID := c.SendEventUnsynced(t, roomID, e)
eventID := c.Unsafe_SendEventUnsynced(t, roomID, e)
t.Logf("SendEventSynced waiting for event ID %s", eventID)
c.MustSyncUntil(t, SyncReq{}, SyncTimelineHas(roomID, func(r gjson.Result) bool {
return r.Get("event_id").Str == eventID
Expand Down Expand Up @@ -298,6 +301,54 @@ func (c *CSAPI) GetDefaultRoomVersion(t TestLike) gomatrixserverlib.RoomVersion
return gomatrixserverlib.RoomVersion(defaultVersion.Str)
}

func (c *CSAPI) GenerateOneTimeKeys(t TestLike, otkCount uint) (deviceKeys map[string]interface{}, oneTimeKeys map[string]interface{}) {
t.Helper()
account := olm.NewAccount()
ed25519Key, curveKey := account.IdentityKeys()

ed25519KeyID := fmt.Sprintf("ed25519:%s", c.DeviceID)
curveKeyID := fmt.Sprintf("curve25519:%s", c.DeviceID)

deviceKeys = map[string]interface{}{
"user_id": c.UserID,
"device_id": c.DeviceID,
"algorithms": []interface{}{"m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"},
"keys": map[string]interface{}{
ed25519KeyID: ed25519Key.String(),
curveKeyID: curveKey.String(),
},
}

signature, _ := account.SignJSON(deviceKeys)

deviceKeys["signatures"] = map[string]interface{}{
c.UserID: map[string]interface{}{
ed25519KeyID: signature,
},
}

account.GenOneTimeKeys(otkCount)
oneTimeKeys = map[string]interface{}{}

for kid, key := range account.OneTimeKeys() {
keyID := fmt.Sprintf("signed_curve25519:%s", kid)
keyMap := map[string]interface{}{
"key": key.String(),
}

signature, _ = account.SignJSON(keyMap)

keyMap["signatures"] = map[string]interface{}{
c.UserID: map[string]interface{}{
ed25519KeyID: signature,
},
}

oneTimeKeys[keyID] = keyMap
}
return deviceKeys, oneTimeKeys
}

// WithRawBody sets the HTTP request body to `body`
func WithRawBody(body []byte) RequestOpt {
return func(req *http.Request) {
Expand Down
2 changes: 1 addition & 1 deletion tests/csapi/keychanges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestKeyChangesLocal(t *testing.T) {

func mustUploadKeys(t *testing.T, user *client.CSAPI) {
t.Helper()
deviceKeys, oneTimeKeys := generateKeys(t, user, 5)
deviceKeys, oneTimeKeys := user.GenerateOneTimeKeys(t, 5)
reqBody := client.WithJSONBody(t, map[string]interface{}{
"device_keys": deviceKeys,
"one_time_keys": oneTimeKeys,
Expand Down
10 changes: 5 additions & 5 deletions tests/csapi/room_relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func TestRelationsPaginationSync(t *testing.T) {
roomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"})
_, token := alice.MustSync(t, client.SyncReq{})

rootEventID := alice.SendEventUnsynced(t, roomID, b.Event{
rootEventID := alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand All @@ -227,9 +227,9 @@ func TestRelationsPaginationSync(t *testing.T) {
})

// Create some related events.
event_id := ""
eventID := ""
for i := 0; i < 5; i++ {
event_id = alice.SendEventUnsynced(t, roomID, b.Event{
eventID = alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand All @@ -245,13 +245,13 @@ func TestRelationsPaginationSync(t *testing.T) {

// Sync and keep the token.
nextBatch := alice.MustSyncUntil(t, client.SyncReq{Since: token}, client.SyncTimelineHas(roomID, func(r gjson.Result) bool {
return r.Get("event_id").Str == event_id
return r.Get("event_id").Str == eventID
}))

// Create more related events.
event_ids := [5]string{}
for i := 0; i < 5; i++ {
event_ids[i] = alice.SendEventUnsynced(t, roomID, b.Event{
event_ids[i] = alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand Down
20 changes: 10 additions & 10 deletions tests/csapi/txnid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestTxnInEvent(t *testing.T) {

txnId := "abcdefg"
// Let's send an event, and wait for it to appear in the timeline.
eventID := c.SendEventUnsyncedWithTxnID(t, roomID, b.Event{
eventID := c.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestTxnScopeOnLocalEcho(t *testing.T) {

txnId := "abdefgh"
// Let's send an event, and wait for it to appear in the timeline.
eventID := c1.SendEventUnsyncedWithTxnID(t, roomID, b.Event{
eventID := c1.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand Down Expand Up @@ -128,15 +128,15 @@ func TestTxnIdempotencyScopedToDevice(t *testing.T) {
},
}
// send an event with set txnId
eventID1 := c1.SendEventUnsyncedWithTxnID(t, roomID, event, txnId)
eventID1 := c1.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, event, txnId)

// Create a second client, inheriting the first device ID.
c2 := deployment.Client(t, "hs1", "")
c2.UserID, c2.AccessToken, c2.DeviceID = c2.LoginUser(t, "alice", "password", client.WithDeviceID(c1.DeviceID))
must.EqualStr(t, c1.DeviceID, c2.DeviceID, "Device ID should be the same")

// send another event with the same txnId via the second client
eventID2 := c2.SendEventUnsyncedWithTxnID(t, roomID, event, txnId)
eventID2 := c2.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, event, txnId)

// the two events should have the same event IDs as they came from the same device
must.EqualStr(t, eventID2, eventID1, "Expected eventID1 and eventID2 to be the same from two clients sharing the same device ID")
Expand Down Expand Up @@ -178,20 +178,20 @@ func TestTxnIdempotency(t *testing.T) {
}

// we send the event and get an event ID back
eventID1 := c1.SendEventUnsyncedWithTxnID(t, roomID1, event1, txnId)
eventID1 := c1.Unsafe_SendEventUnsyncedWithTxnID(t, roomID1, event1, txnId)

// we send the identical event again and should get back the same event ID
eventID2 := c1.SendEventUnsyncedWithTxnID(t, roomID1, event1, txnId)
eventID2 := c1.Unsafe_SendEventUnsyncedWithTxnID(t, roomID1, event1, txnId)

must.EqualStr(t, eventID2, eventID1, "Expected eventID1 and eventID2 to be the same, but they were not")

// even if we change the content we should still get back the same event ID as transaction ID is the same
eventID3 := c1.SendEventUnsyncedWithTxnID(t, roomID1, event2, txnId)
eventID3 := c1.Unsafe_SendEventUnsyncedWithTxnID(t, roomID1, event2, txnId)

must.EqualStr(t, eventID3, eventID1, "Expected eventID3 and eventID2 to be the same even with different content, but they were not")

// if we change the room ID we should be able to use the same transaction ID
eventID4 := c1.SendEventUnsyncedWithTxnID(t, roomID2, event1, txnId)
eventID4 := c1.Unsafe_SendEventUnsyncedWithTxnID(t, roomID2, event1, txnId)

must.NotEqualStr(t, eventID4, eventID3, "Expected eventID4 and eventID3 to be different, but they were not")
}
Expand All @@ -217,7 +217,7 @@ func TestTxnIdWithRefreshToken(t *testing.T) {

txnId := "abcdef"
// We send an event
eventID1 := c.SendEventUnsyncedWithTxnID(t, roomID, b.Event{
eventID1 := c.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand All @@ -233,7 +233,7 @@ func TestTxnIdWithRefreshToken(t *testing.T) {
c.MustSyncUntil(t, client.SyncReq{}, mustHaveTransactionIDForEvent(t, roomID, eventID1, txnId))

// We try sending the event again with the same transaction ID
eventID2 := c.SendEventUnsyncedWithTxnID(t, roomID, b.Event{
eventID2 := c.Unsafe_SendEventUnsyncedWithTxnID(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
Expand Down
52 changes: 2 additions & 50 deletions tests/csapi/upload_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"

"github.com/tidwall/gjson"
"maunium.net/go/mautrix/crypto/olm"

"github.com/matrix-org/complement/client"
"github.com/matrix-org/complement/internal/b"
Expand All @@ -23,7 +22,8 @@ func TestUploadKey(t *testing.T) {
alice := deployment.Client(t, "hs1", "@alice:hs1")
bob := deployment.Client(t, "hs1", "@bob:hs1")

deviceKeys, oneTimeKeys := generateKeys(t, alice, 1)
deviceKeys, oneTimeKeys := alice.GenerateOneTimeKeys(t, 1)

t.Run("Parallel", func(t *testing.T) {
// sytest: Can upload device keys
t.Run("Can upload device keys", func(t *testing.T) {
Expand Down Expand Up @@ -172,51 +172,3 @@ func TestUploadKey(t *testing.T) {
})
})
}

func generateKeys(t *testing.T, user *client.CSAPI, otkCount uint) (deviceKeys map[string]interface{}, oneTimeKeys map[string]interface{}) {
t.Helper()
account := olm.NewAccount()
ed25519Key, curveKey := account.IdentityKeys()

ed25519KeyID := fmt.Sprintf("ed25519:%s", user.DeviceID)
curveKeyID := fmt.Sprintf("curve25519:%s", user.DeviceID)

deviceKeys = map[string]interface{}{
"user_id": user.UserID,
"device_id": user.DeviceID,
"algorithms": []interface{}{"m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"},
"keys": map[string]interface{}{
ed25519KeyID: ed25519Key.String(),
curveKeyID: curveKey.String(),
},
}

signature, _ := account.SignJSON(deviceKeys)

deviceKeys["signatures"] = map[string]interface{}{
user.UserID: map[string]interface{}{
ed25519KeyID: signature,
},
}

account.GenOneTimeKeys(otkCount)
oneTimeKeys = map[string]interface{}{}

for kid, key := range account.OneTimeKeys() {
keyID := fmt.Sprintf("signed_curve25519:%s", kid)
keyMap := map[string]interface{}{
"key": key.String(),
}

signature, _ = account.SignJSON(keyMap)

keyMap["signatures"] = map[string]interface{}{
user.UserID: map[string]interface{}{
ed25519KeyID: signature,
},
}

oneTimeKeys[keyID] = keyMap
}
return deviceKeys, oneTimeKeys
}
51 changes: 1 addition & 50 deletions tests/federation_upload_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"

"github.com/tidwall/gjson"
"maunium.net/go/mautrix/crypto/olm"

"github.com/matrix-org/complement/client"
"github.com/matrix-org/complement/internal/b"
Expand All @@ -25,7 +24,7 @@ func TestFederationKeyUploadQuery(t *testing.T) {
// Do an initial sync so that we can see the changes come down sync.
_, nextBatchBeforeKeyUpload := bob.MustSync(t, client.SyncReq{})

deviceKeys, oneTimeKeys := generateKeys(t, alice, 1)
deviceKeys, oneTimeKeys := alice.GenerateOneTimeKeys(t, 1)
// Upload keys
reqBody := client.WithJSONBody(t, map[string]interface{}{
"device_keys": deviceKeys,
Expand Down Expand Up @@ -134,51 +133,3 @@ func TestFederationKeyUploadQuery(t *testing.T) {
})
})
}

func generateKeys(t *testing.T, user *client.CSAPI, otkCount uint) (deviceKeys map[string]interface{}, oneTimeKeys map[string]interface{}) {
t.Helper()
account := olm.NewAccount()
ed25519Key, curveKey := account.IdentityKeys()

ed25519KeyID := fmt.Sprintf("ed25519:%s", user.DeviceID)
curveKeyID := fmt.Sprintf("curve25519:%s", user.DeviceID)

deviceKeys = map[string]interface{}{
"user_id": user.UserID,
"device_id": user.DeviceID,
"algorithms": []interface{}{"m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"},
"keys": map[string]interface{}{
ed25519KeyID: ed25519Key.String(),
curveKeyID: curveKey.String(),
},
}

signature, _ := account.SignJSON(deviceKeys)

deviceKeys["signatures"] = map[string]interface{}{
user.UserID: map[string]interface{}{
ed25519KeyID: signature,
},
}

account.GenOneTimeKeys(otkCount)
oneTimeKeys = map[string]interface{}{}

for kid, key := range account.OneTimeKeys() {
keyID := fmt.Sprintf("signed_curve25519:%s", kid)
keyMap := map[string]interface{}{
"key": key.String(),
}

signature, _ = account.SignJSON(keyMap)

keyMap["signatures"] = map[string]interface{}{
user.UserID: map[string]interface{}{
ed25519KeyID: signature,
},
}

oneTimeKeys[keyID] = keyMap
}
return deviceKeys, oneTimeKeys
}

0 comments on commit 81193ac

Please sign in to comment.