From 48a87adf39dad53550329da64f87b52085a26105 Mon Sep 17 00:00:00 2001 From: finiteprods Date: Tue, 11 Feb 2020 08:17:27 +0100 Subject: [PATCH 1/4] PB-398 add/remove selected in round --- xain_fl/coordinator/round.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/xain_fl/coordinator/round.py b/xain_fl/coordinator/round.py index d4fcfa836..cede5b677 100644 --- a/xain_fl/coordinator/round.py +++ b/xain_fl/coordinator/round.py @@ -22,6 +22,26 @@ def __init__(self, participant_ids: List[str]) -> None: self.participant_ids = participant_ids self.updates: Dict[str, Dict] = {} + def add_selected(self, more_ids: List[str]) -> None: + """Add to the collection of selected participants. + + Args: + more_ids: ids of participants to add. + """ + + self.participant_ids.extend(more_ids) + + def remove_selected(self, participant_id: str) -> None: + """Remove from the collection of selected participants. + + Args: + participant_id: id of participant to remove. + """ + + self.participant_ids = [ + id for id in self.participant_ids if id != participant_id + ] + def add_updates( self, participant_id: str, @@ -29,7 +49,7 @@ def add_updates( aggregation_data: int, metrics: Dict[str, ndarray], ) -> None: - """Valid a participant's update for the round. + """Add a participant's update for the round. Args: participant_id: The id of the participant making the request. @@ -65,7 +85,7 @@ def is_finished(self) -> bool: round. `False` otherwise. """ - return len(self.updates) == len(self.participant_ids) + return all(id in self.updates for id in self.participant_ids) def get_weight_updates(self) -> Tuple[List[ndarray], List[int]]: """Get a list of all participants weight updates. @@ -75,7 +95,8 @@ def get_weight_updates(self) -> Tuple[List[ndarray], List[int]]: The lists of model weights and aggregation meta data from all participants. """ + updates = [self.updates[id] for id in self.participant_ids] return ( - [v["model_weights"] for v in self.updates.values()], - [v["aggregation_data"] for v in self.updates.values()], + [upd["model_weights"] for upd in updates], + [upd["aggregation_data"] for upd in updates], ) From 5115390c2b8591977242aa6d732a5bf72eb61322 Mon Sep 17 00:00:00 2001 From: finiteprods Date: Tue, 11 Feb 2020 08:18:52 +0100 Subject: [PATCH 2/4] PB-398 order controller for selecting init segment of sorted --- xain_fl/fl/coordinator/controller.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/xain_fl/fl/coordinator/controller.py b/xain_fl/fl/coordinator/controller.py index 470b88679..df199cdfc 100644 --- a/xain_fl/fl/coordinator/controller.py +++ b/xain_fl/fl/coordinator/controller.py @@ -59,7 +59,7 @@ def select_ids(self, participant_ids: List[str]) -> List[str]: class IdController(Controller): - """[summary + """[summary] ... todo: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) """ @@ -78,6 +78,28 @@ def select_ids(self, participant_ids: List[str]) -> List[str]: return participant_ids +class OrderController(Controller): + """[summary] + + ... todo: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) + """ + + def select_ids(self, participant_ids: List[str]) -> List[str]: + """Selects participants according to order. + + Args: + participant_ids (:obj:`list` of :obj:`str`): The list of IDs of all the + available participants, a subset of which will be selected. + + Returns: + :obj:`list` of :obj:`str`: List of selected participant IDs + """ + + num_ids_to_select = self.get_num_ids_to_select(len(participant_ids)) + sorted_ids = sorted(participant_ids) + return sorted_ids[:num_ids_to_select] + + class RandomController(Controller): """[summary] From 57a9f6b9f42739e5b35bcde3b014bb317049ea4f Mon Sep 17 00:00:00 2001 From: finiteprods Date: Tue, 11 Feb 2020 08:28:38 +0100 Subject: [PATCH 3/4] PB-398 split test_remove_participants; test_select_outstanding --- tests/test_coordinator.py | 85 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 25973374c..203b940b3 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -16,6 +16,7 @@ from xain_proto.np import proto_to_ndarray from xain_fl.coordinator.coordinator import Coordinator +from xain_fl.fl.coordinator.controller import OrderController from xain_fl.tools.exceptions import ( DuplicatedUpdateError, InvalidRequestError, @@ -268,7 +269,7 @@ def test_duplicated_update_submit(): coordinator.on_message(EndTrainingRoundRequest(), "participant1") -def test_remove_participant(): +def test_remove_selected_participant(): """[summary] .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) @@ -279,18 +280,47 @@ def test_remove_participant(): ) coordinator.on_message(RendezvousRequest(), "participant1") + assert coordinator.participants.len() == 1 + assert coordinator.round.participant_ids == ["participant1"] assert coordinator.state == State.ROUND coordinator.remove_participant("participant1") assert coordinator.participants.len() == 0 + assert coordinator.round.participant_ids == [] assert coordinator.state == State.STANDBY coordinator.on_message(RendezvousRequest(), "participant1") + assert coordinator.participants.len() == 1 + assert coordinator.round.participant_ids == ["participant1"] assert coordinator.state == State.ROUND +def test_remove_unselected_participant(): + """[summary] + + .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) + """ + + coordinator = Coordinator( + minimum_participants_in_round=1, fraction_of_participants=0.5 + ) + coordinator.on_message(RendezvousRequest(), "participant1") + coordinator.on_message(RendezvousRequest(), "participant2") + + assert coordinator.participants.len() == 2 + assert len(coordinator.round.participant_ids) == 1 + + # override selection + coordinator.round.participant_ids = ["participant1"] + + coordinator.remove_participant("participant2") + + assert coordinator.participants.len() == 1 + assert coordinator.round.participant_ids == ["participant1"] + + def test_number_of_selected_participants(): """[summary] @@ -345,3 +375,56 @@ def test_correct_round_advertised_to_participants(): # state STANDBY will be advertised to participant2 (which has NOT been selected) response = coordinator.on_message(HeartbeatRequest(), "participant2") assert response.state == State.STANDBY + + +def test_select_outstanding(): + """[summary] + + .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) + """ + + # setup: select first 3 of 4 in order per round + coordinator = Coordinator( + minimum_participants_in_round=3, + fraction_of_participants=0.75, + controller=OrderController(), + ) + coordinator.on_message(RendezvousRequest(), "participant1") + coordinator.on_message(RendezvousRequest(), "participant2") + coordinator.on_message(RendezvousRequest(), "participant3") + coordinator.on_message(RendezvousRequest(), "participant4") + + # 4 connected hence round starts + assert coordinator.state == State.ROUND + assert coordinator.participants.len() == 4 + # selection is triggered: order-controller guarantees it's [P1, P2, P3] + assert coordinator.round.participant_ids == [ + "participant1", + "participant2", + "participant3", + ] + + coordinator.remove_participant("participant3") + + # round pauses + assert coordinator.state == State.STANDBY + assert coordinator.participants.len() == 3 + assert coordinator.round.participant_ids == ["participant1", "participant2"] + + coordinator.remove_participant("participant1") + + assert coordinator.participants.len() == 2 + assert coordinator.round.participant_ids == ["participant2"] + + coordinator.on_message(RendezvousRequest(), "participant5") + coordinator.on_message(RendezvousRequest(), "participant6") + + # back up to 4 (P2, P4, P5, P6) so round resumes + assert coordinator.state == State.ROUND + assert coordinator.participants.len() == 4 + # selection triggered: P2 still selected with 2 outstanding from [P4, P5, P6] + assert coordinator.round.participant_ids == [ + "participant2", + "participant4", + "participant5", + ] From 4d5ad72cf1ee1a30016d0a4f92c131f9fb7abcdb Mon Sep 17 00:00:00 2001 From: finiteprods Date: Tue, 11 Feb 2020 08:30:47 +0100 Subject: [PATCH 4/4] PB-398 fix remove_participant; fix _handle_rendezvous --- xain_fl/coordinator/coordinator.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/xain_fl/coordinator/coordinator.py b/xain_fl/coordinator/coordinator.py index 8b0876718..dbc0f14a9 100644 --- a/xain_fl/coordinator/coordinator.py +++ b/xain_fl/coordinator/coordinator.py @@ -223,8 +223,10 @@ def remove_participant(self, participant_id: str) -> None: participant_id: The id of the participant to remove. """ - self.participants.remove(participant_id) logger.info("Removing participant", participant_id=participant_id) + self.participants.remove(participant_id) + # remove from selected if necessary + self.round.remove_selected(participant_id) if self.participants.len() < self.minimum_connected_participants: self.state = State.STANDBY @@ -236,6 +238,17 @@ def select_participant_ids_and_init_round(self) -> None: selected_ids = self.controller.select_ids(self.participants.ids()) self.round = Round(selected_ids) + def select_outstanding(self) -> List[str]: + """Selects participant ids.""" + + selected = set(self.round.participant_ids) + num_outstanding = self.minimum_participants_in_round - len(selected) + pool = set(self.participants.ids()) - selected + frac = num_outstanding / len(pool) + + self.controller.fraction_of_participants = frac + return self.controller.select_ids(list(pool)) + def _handle_rendezvous( self, _message: RendezvousRequest, participant_id: str ) -> RendezvousResponse: @@ -261,7 +274,10 @@ def _handle_rendezvous( # Select participants and change the state to ROUND if the latest added participant # lets us meet the minimum number of connected participants if self.participants.len() == self.minimum_connected_participants: - self.select_participant_ids_and_init_round() + # select enough to fill round if needed + if len(self.round.participant_ids) < self.minimum_participants_in_round: + ids = self.select_outstanding() + self.round.add_selected(ids) self.state = State.ROUND else: