Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
PB-407 add more debug level logging
Browse files Browse the repository at this point in the history
  • Loading branch information
janpetschexain committed Feb 24, 2020
1 parent cadcec5 commit 502bac0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 5 additions & 4 deletions tests/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ def test_message_loop(mock_heartbeat_request, _mock_sleep, _mock_event):
channel = mock.MagicMock()
terminate_event = threading.Event()
state_record = StateRecord()
participant_id = "123"

message_loop(channel, state_record, terminate_event)
message_loop(channel, participant_id, state_record, terminate_event)

# check that the heartbeat is sent exactly twice
expected_call = mock.call(round=-1, state=State.READY)
Expand Down Expand Up @@ -293,9 +294,9 @@ def test_start_training_round(coordinator_service):
# simulate a participant communicating with coordinator via channel
with grpc.insecure_channel("localhost:50051") as channel:
# we need to rendezvous before we can send any other requests
rendezvous(channel)
rendezvous(channel, participant_id="123")
# call StartTrainingRound service method on coordinator
epochs, epoch_base = start_training_round(channel)
epochs, epoch_base = start_training_round(channel, participant_id="123")

# check global model received
assert epochs == 5
Expand Down Expand Up @@ -364,7 +365,7 @@ def test_end_training_round(

with grpc.insecure_channel("localhost:50051") as channel:
# we first need to rendezvous before we can send any other request
rendezvous(channel)
rendezvous(channel, participant_id="123")
# call EndTrainingRound service method on coordinator
participant_store.write_weights("participant1", 0, test_weights)
end_training_round(
Expand Down
3 changes: 2 additions & 1 deletion xain_fl/coordinator/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def select_outstanding(self) -> List[str]:
frac = num_outstanding / len(pool)

self.controller.fraction_of_participants = frac
return self.controller.select_ids(list(pool))
outstanding: List[str] = self.controller.select_ids(list(pool))
return outstanding

def _handle_rendezvous(
self, _message: RendezvousRequest, participant_id: str
Expand Down

0 comments on commit 502bac0

Please sign in to comment.