From 2cbeaae959adea9bc532bfeafcb75736fa9499fe Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 21 Jul 2022 20:01:27 +0100 Subject: [PATCH] Allow explicitly specified `/state` and `/state_ids` requests to complete --- ...federation_room_join_partial_state_test.go | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index 3210c820..5b3ae3b6 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -592,6 +592,8 @@ type partialStateJoinResult struct { ServerRoom *federation.ServerRoom fedStateIdsRequestReceivedWaiter *Waiter fedStateIdsSendResponseWaiter *Waiter + // the set of events for which we will not block `/state` or `/state_ids` requests. + fedStateIdsAllowedEvents map[string]bool } // beginPartialStateJoin spins up a room on a complement server, @@ -627,6 +629,7 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU // some things for orchestration result.fedStateIdsRequestReceivedWaiter = NewWaiter() result.fedStateIdsSendResponseWaiter = NewWaiter() + result.fedStateIdsAllowedEvents = make(map[string]bool) // create the room on the complement server, with charlie and derek as members roomVer := joiningUser.GetDefaultRoomVersion(t) @@ -642,10 +645,17 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU // register a handler for /state_ids requests, which finishes fedStateIdsRequestReceivedWaiter, then // waits for fedStateIdsSendResponseWaiter and sends a reply - handleStateIdsRequests(t, result.Server, result.ServerRoom, result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter) + handleStateIdsRequests( + t, + result.Server, + result.ServerRoom, + result.fedStateIdsRequestReceivedWaiter, + result.fedStateIdsSendResponseWaiter, + result.fedStateIdsAllowedEvents, + ) // a handler for /state requests, which sends a sensible response - handleStateRequests(t, result.Server, result.ServerRoom, nil, nil) + handleStateRequests(t, result.Server, result.ServerRoom, nil, nil, nil) // have joiningUser join the room by room ID. joiningUser.JoinRoom(t, result.ServerRoom.RoomID, []string{result.Server.ServerName()}) @@ -693,6 +703,12 @@ func (psj *partialStateJoinResult) CreateMessageEvent(t *testing.T, senderLocalp return event } +// allow a /state_ids request for a given event to complete before FinishStateRequest has been called. +// only applies to new incoming requests, and not any currently blocked ones. +func (psj *partialStateJoinResult) AllowStateRequestForEvent(eventID string) { + psj.fedStateIdsAllowedEvents[eventID] = true +} + // wait for a /state_ids request for the test room to arrive func (psj *partialStateJoinResult) AwaitStateIdsRequest(t *testing.T) { psj.fedStateIdsRequestReceivedWaiter.Waitf(t, 5*time.Second, "Waiting for /state_ids request") @@ -709,7 +725,7 @@ func (psj *partialStateJoinResult) FinishStateRequest() { // if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response. func handleStateIdsRequests( t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom, - requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, + requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, allowedEvents map[string]bool, ) { srv.Mux().Handle( fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", serverRoom.RoomID), @@ -719,7 +735,8 @@ func handleStateIdsRequests( if requestReceivedWaiter != nil { requestReceivedWaiter.Finish() } - if sendResponseWaiter != nil { + if !allowedEvents[queryParams["event_id"][0]] && + sendResponseWaiter != nil { sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state_ids request") } t.Logf("Replying to /state_ids request") @@ -744,7 +761,7 @@ func handleStateIdsRequests( // if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response. func handleStateRequests( t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom, - requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, + requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, allowedEvents map[string]bool, ) { srv.Mux().Handle( fmt.Sprintf("/_matrix/federation/v1/state/%s", serverRoom.RoomID), @@ -754,7 +771,8 @@ func handleStateRequests( if requestReceivedWaiter != nil { requestReceivedWaiter.Finish() } - if sendResponseWaiter != nil { + if !allowedEvents[queryParams["event_id"][0]] && + sendResponseWaiter != nil { sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state request") } res := gomatrixserverlib.RespState{