From c261c3b11792050814bce7d10e870ae066790c47 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 16 May 2024 11:38:35 -0700 Subject: [PATCH] add cyclic wfcontroller (#2554) --- .../executors/task_script_runner.py | 2 +- nvflare/app_common/workflows/base_fedavg.py | 24 -------- nvflare/app_common/workflows/cyclic.py | 60 +++++++++++++++++++ .../app_common/workflows/model_controller.py | 18 ++++++ nvflare/app_common/workflows/wf_controller.py | 11 ++++ 5 files changed, 90 insertions(+), 25 deletions(-) create mode 100644 nvflare/app_common/workflows/cyclic.py diff --git a/nvflare/app_common/executors/task_script_runner.py b/nvflare/app_common/executors/task_script_runner.py index bd563e2b43..053db6f8af 100644 --- a/nvflare/app_common/executors/task_script_runner.py +++ b/nvflare/app_common/executors/task_script_runner.py @@ -47,7 +47,7 @@ def __init__(self, site_name: str, script_path: str, script_args: str = None, re def run(self): """Call the task_fn with any required arguments.""" - self.logger.info(f"\n start task run() with full path: {self.script_full_path}") + self.logger.info(f"start task run() with full path: {self.script_full_path}") try: curr_argv = sys.argv builtins.print = log_print if self.redirect_print_to_log else print_fn diff --git a/nvflare/app_common/workflows/base_fedavg.py b/nvflare/app_common/workflows/base_fedavg.py index f691034cbe..1f9fd96021 100644 --- a/nvflare/app_common/workflows/base_fedavg.py +++ b/nvflare/app_common/workflows/base_fedavg.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random from typing import List from nvflare.apis.fl_constant import FLMetaKey @@ -48,7 +47,6 @@ def __init__( The model_persistor will also save the model after training. Provides the default implementations for the follow routines: - - def sample_clients(self, min_clients) - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel - def update_model(self, aggr_result) @@ -74,28 +72,6 @@ def __init__( self.current_round = None - def sample_clients(self, num_clients): - """Called by the `run` routine to get a list of available clients. - - Args: - min_clients: number of clients to return. - - Returns: list of clients. - - """ - - clients = self.engine.get_clients() - - if num_clients <= len(clients): - random.shuffle(clients) - clients = clients[0:num_clients] - else: - self.info( - f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients." - ) - - return clients - @staticmethod def _check_results(results: List[FLModel]): empty_clients = [] diff --git a/nvflare/app_common/workflows/cyclic.py b/nvflare/app_common/workflows/cyclic.py new file mode 100644 index 0000000000..2ed67bd3ab --- /dev/null +++ b/nvflare/app_common/workflows/cyclic.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .wf_controller import WFController + + +class Cyclic(WFController): + def __init__( + self, + *args, + min_clients: int = 2, + num_rounds: int = 5, + start_round: int = 0, + **kwargs, + ): + """The Cyclic Workflow controller to implement the Cyclic Weight Transfer (CWT) algorithm. + + Args: + min_clients (int, optional): The minimum number of clients. Defaults to 2. + num_rounds (int, optional): The total number of training rounds. Defaults to 5. + start_round (int, optional): The starting round number. Defaults to 0 + """ + super().__init__(*args, **kwargs) + + self.min_clients = min_clients + self.num_rounds = num_rounds + self.start_round = start_round + self.current_round = None + + def run(self) -> None: + self.info("Start Cyclic.") + + model = self.load_model() + model.start_round = self.start_round + model.total_rounds = self.num_rounds + + for self.current_round in range(self.start_round, self.start_round + self.num_rounds): + self.info(f"Round {self.current_round} started.") + model.current_round = self.current_round + + clients = self.sample_clients(self.min_clients) + + for client in clients: + result = self.send_model_and_wait(targets=[client], data=model)[0] + model.params, model.meta = result.params, result.meta + + self.save_model(model) + + self.info("Finished Cyclic.") diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 8d511b9805..02956755de 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -13,6 +13,7 @@ # limitations under the License. import gc +import random from abc import ABC, abstractmethod from typing import Callable, List, Union @@ -343,6 +344,23 @@ def save_model(self, model): else: self.error("persistor not configured, model will not be saved") + def sample_clients(self, num_clients): + clients = self.engine.get_clients() + + if num_clients < len(clients): + random.shuffle(clients) + clients = clients[0:num_clients] + self.info( + f"num_clients ({num_clients}) is less than the number of available clients. Returning a random subset of {num_clients} clients." + ) + elif num_clients > len(clients): + self.info( + f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients." + ) + self.info(f"Sampled clients: {[client.name for client in clients]}") + + return clients + def stop_controller(self, fl_ctx: FLContext): self.fl_ctx = fl_ctx self.finalize() diff --git a/nvflare/app_common/workflows/wf_controller.py b/nvflare/app_common/workflows/wf_controller.py index 021de51c10..36db78f871 100644 --- a/nvflare/app_common/workflows/wf_controller.py +++ b/nvflare/app_common/workflows/wf_controller.py @@ -116,3 +116,14 @@ def save_model(self, model: FLModel): None """ super().save_model(model) + + def sample_clients(self, num_clients): + """Returns a list of available clients. + + Args: + min_clients: number of clients to return. + + Returns: list of clients. + + """ + return super().sample_clients(num_clients)