diff --git a/internal/client/client.go b/internal/client/client.go index 9a52510d..71bbb5fd 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -335,24 +335,27 @@ func (c *CSAPI) MustSync(t *testing.T, syncReq SyncReq) (gjson.Result, string) { // check functions return no error. Returns the final/latest since token. // // Initial /sync example: (no since token) -// bob.InviteRoom(t, roomID, alice.UserID) -// alice.JoinRoom(t, roomID, nil) -// alice.MustSyncUntil(t, client.SyncReq{}, client.SyncJoinedTo(alice.UserID, roomID)) +// +// bob.InviteRoom(t, roomID, alice.UserID) +// alice.JoinRoom(t, roomID, nil) +// alice.MustSyncUntil(t, client.SyncReq{}, client.SyncJoinedTo(alice.UserID, roomID)) // // Incremental /sync example: (test controls since token) -// since := alice.MustSyncUntil(t, client.SyncReq{TimeoutMillis: "0"}) // get a since token -// bob.InviteRoom(t, roomID, alice.UserID) -// since = alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncInvitedTo(alice.UserID, roomID)) -// alice.JoinRoom(t, roomID, nil) -// alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncJoinedTo(alice.UserID, roomID)) +// +// since := alice.MustSyncUntil(t, client.SyncReq{TimeoutMillis: "0"}) // get a since token +// bob.InviteRoom(t, roomID, alice.UserID) +// since = alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncInvitedTo(alice.UserID, roomID)) +// alice.JoinRoom(t, roomID, nil) +// alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncJoinedTo(alice.UserID, roomID)) // // Checking multiple parts of /sync: -// alice.MustSyncUntil( -// t, client.SyncReq{}, -// client.SyncJoinedTo(alice.UserID, roomID), -// client.SyncJoinedTo(alice.UserID, roomID2), -// client.SyncJoinedTo(alice.UserID, roomID3), -// ) +// +// alice.MustSyncUntil( +// t, client.SyncReq{}, +// client.SyncJoinedTo(alice.UserID, roomID), +// client.SyncJoinedTo(alice.UserID, roomID2), +// client.SyncJoinedTo(alice.UserID, roomID3), +// ) // // Check functions are unordered and independent. Once a check function returns true it is removed // from the list of checks and won't be called again. @@ -438,7 +441,81 @@ func (c *CSAPI) LoginUser(t *testing.T, localpart, password string) (userID, acc return userID, accessToken, deviceID } -//RegisterUser will register the user with given parameters and +// LoginUserWithDeviceID will log in to a homeserver on an existing device +func (c *CSAPI) LoginUserWithDeviceID(t *testing.T, localpart, password, deviceID string) (userID, accessToken string) { + t.Helper() + reqBody := map[string]interface{}{ + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": localpart, + }, + "device_id": deviceID, + "password": password, + "type": "m.login.password", + } + res := c.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "login"}, WithJSONBody(t, reqBody)) + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("unable to read response body: %v", err) + } + + userID = gjson.GetBytes(body, "user_id").Str + accessToken = gjson.GetBytes(body, "access_token").Str + if gjson.GetBytes(body, "device_id").Str != deviceID { + t.Fatalf("device_id returned by login does not match the one requested") + } + return userID, accessToken +} + +// LoginUserWithRefreshToken will log in to a homeserver, with refresh token enabled, +// and create a new device on an existing user. +func (c *CSAPI) LoginUserWithRefreshToken(t *testing.T, localpart, password string) (userID, accessToken, refreshToken, deviceID string, expiresInMs int64) { + t.Helper() + reqBody := map[string]interface{}{ + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": localpart, + }, + "password": password, + "type": "m.login.password", + "refresh_token": true, + } + res := c.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "login"}, WithJSONBody(t, reqBody)) + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("unable to read response body: %v", err) + } + + userID = gjson.GetBytes(body, "user_id").Str + accessToken = gjson.GetBytes(body, "access_token").Str + deviceID = gjson.GetBytes(body, "device_id").Str + refreshToken = gjson.GetBytes(body, "refresh_token").Str + expiresInMs = gjson.GetBytes(body, "expires_in_ms").Int() + return userID, accessToken, refreshToken, deviceID, expiresInMs +} + +// RefreshToken will consume a refresh token and return a new access token and refresh token. +func (c *CSAPI) ConsumeRefreshToken(t *testing.T, refreshToken string) (newAccessToken, newRefreshToken string, expiresInMs int64) { + t.Helper() + reqBody := map[string]interface{}{ + "refresh_token": refreshToken, + } + res := c.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "refresh"}, WithJSONBody(t, reqBody)) + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("unable to read response body: %v", err) + } + + newAccessToken = gjson.GetBytes(body, "access_token").Str + newRefreshToken = gjson.GetBytes(body, "refresh_token").Str + expiresInMs = gjson.GetBytes(body, "expires_in_ms").Int() + return newAccessToken, newRefreshToken, expiresInMs +} + +// RegisterUser will register the user with given parameters and // return user ID & access token, and fail the test on network error func (c *CSAPI) RegisterUser(t *testing.T, localpart, password string) (userID, accessToken, deviceID string) { t.Helper() @@ -598,12 +675,13 @@ func (c *CSAPI) MustDoFunc(t *testing.T, method string, paths []string, opts ... // // Fails the test if an HTTP request could not be made or if there was a network error talking to the // server. To do assertions on the HTTP response, see the `must` package. For example: -// must.MatchResponse(t, res, match.HTTPResponse{ -// StatusCode: 400, -// JSON: []match.JSON{ -// match.JSONKeyEqual("errcode", "M_INVALID_USERNAME"), -// }, -// }) +// +// must.MatchResponse(t, res, match.HTTPResponse{ +// StatusCode: 400, +// JSON: []match.JSON{ +// match.JSONKeyEqual("errcode", "M_INVALID_USERNAME"), +// }, +// }) func (c *CSAPI) DoFunc(t *testing.T, method string, paths []string, opts ...RequestOpt) *http.Response { t.Helper() for i := range paths { diff --git a/internal/docker/deployment.go b/internal/docker/deployment.go index cefb2404..cdd79356 100644 --- a/internal/docker/deployment.go +++ b/internal/docker/deployment.go @@ -91,7 +91,7 @@ func (d *Deployment) Client(t *testing.T, hsName, userID string) *client.CSAPI { // NewUser creates a new user as a convenience method to RegisterUser. // -//It registers the user with a deterministic password, and without admin privileges. +// It registers the user with a deterministic password, and without admin privileges. func (d *Deployment) NewUser(t *testing.T, localpart, hs string) *client.CSAPI { return d.RegisterUser(t, hs, localpart, "complement_meets_min_pasword_req_"+localpart, false) } diff --git a/tests/csapi/txnid_scope_test.go b/tests/csapi/txnid_scope_test.go new file mode 100644 index 00000000..7276bc02 --- /dev/null +++ b/tests/csapi/txnid_scope_test.go @@ -0,0 +1,95 @@ +package csapi_tests + +import ( + "testing" + + "github.com/matrix-org/complement/internal/b" + "github.com/matrix-org/complement/internal/client" + "github.com/tidwall/gjson" +) + +// TestTxnAfterRefresh tests that when a client refreshes its access token, +// it still gets back a transaction ID in the sync response. +func TestTxnAfterRefresh(t *testing.T) { + deployment := Deploy(t, b.BlueprintCleanHS) + defer deployment.Destroy(t) + + deployment.RegisterUser(t, "hs1", "alice", "password", false) + + c := deployment.Client(t, "hs1", "") + + var refreshToken string + c.UserID, c.AccessToken, refreshToken, c.DeviceID, _ = c.LoginUserWithRefreshToken(t, "alice", "password") + + // Create a room where we can send events. + roomID := c.CreateRoom(t, map[string]interface{}{}) + + // Let's send an event, and wait for it to appear in the sync. + eventID := c.SendEventUnsynced(t, roomID, b.Event{ + Type: "m.room.message", + Content: map[string]interface{}{ + "msgtype": "m.text", + "body": "first", + }, + }) + + // When syncing, we should find the event and it should have a transaction ID. + c.MustSyncUntil(t, client.SyncReq{}, client.SyncTimelineHas(roomID, func(r gjson.Result) bool { + return r.Get("event_id").Str == eventID && r.Get("unsigned.transaction_id").Exists() + })) + + // Now do the same, but refresh the token before syncing. + eventID = c.SendEventUnsynced(t, roomID, b.Event{ + Type: "m.room.message", + Content: map[string]interface{}{ + "msgtype": "m.text", + "body": "second", + }, + }) + + // Use the refresh token to get a new access token. + c.AccessToken, refreshToken, _ = c.ConsumeRefreshToken(t, refreshToken) + + // When syncing, we should find the event and it should also have a transaction ID. + c.MustSyncUntil(t, client.SyncReq{}, client.SyncTimelineHas(roomID, func(r gjson.Result) bool { + return r.Get("event_id").Str == eventID && r.Get("unsigned.transaction_id").Exists() + })) +} + +// TestTxnScope tests that transaction IDs are scoped to the access token, not the device +func TestTxnScope(t *testing.T) { + deployment := Deploy(t, b.BlueprintCleanHS) + defer deployment.Destroy(t) + + deployment.RegisterUser(t, "hs1", "alice", "password", false) + + // Create a first client, which allocates a device ID. + c1 := deployment.Client(t, "hs1", "") + c1.UserID, c1.AccessToken, c1.DeviceID = c1.LoginUser(t, "alice", "password") + // Create a room where we can send events. + roomID := c1.CreateRoom(t, map[string]interface{}{}) + + // Let's send an event, and wait for it to appear in the timeline. + eventID := c1.SendEventUnsynced(t, roomID, b.Event{ + Type: "m.room.message", + Content: map[string]interface{}{ + "msgtype": "m.text", + "body": "first", + }, + }) + + // When syncing, we should find the event and it should have a transaction ID on the first client. + c1.MustSyncUntil(t, client.SyncReq{}, client.SyncTimelineHas(roomID, func(r gjson.Result) bool { + return r.Get("event_id").Str == eventID && r.Get("unsigned.transaction_id").Exists() + })) + + // Create a second client, inheriting the same device ID. + c2 := deployment.Client(t, "hs1", "") + c2.UserID, c2.AccessToken = c2.LoginUserWithDeviceID(t, "alice", "password", c1.DeviceID) + c2.DeviceID = c1.DeviceID + + // When syncing, we should find the event and it should *not* have a transaction ID on the second client. + c2.MustSyncUntil(t, client.SyncReq{}, client.SyncTimelineHas(roomID, func(r gjson.Result) bool { + return r.Get("event_id").Str == eventID && !r.Get("unsigned.transaction_id").Exists() + })) +}