Skip to content

Commit

Permalink
add cyclic wfcontroller
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed May 3, 2024
1 parent 775880f commit 64c1687
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 24 deletions.
2 changes: 1 addition & 1 deletion nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, script_path: str, script_args: str = None):

def run(self):
"""Call the task_fn with any required arguments."""
self.logger.info(f"\n start task run() with {self.script_path}")
self.logger.info(f"start task run() with {self.script_path}")
try:
import runpy

Expand Down
23 changes: 0 additions & 23 deletions nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(
The model_persistor will also save the model after training.
Provides the default implementations for the follow routines:
- def sample_clients(self, min_clients)
- def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
- def update_model(self, aggr_result)
Expand All @@ -74,28 +73,6 @@ def __init__(

self.current_round = None

def sample_clients(self, num_clients):
"""Called by the `run` routine to get a list of available clients.
Args:
min_clients: number of clients to return.
Returns: list of clients.
"""

clients = self.engine.get_clients()

if num_clients <= len(clients):
random.shuffle(clients)
clients = clients[0:num_clients]
else:
self.info(
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
)

return clients

@staticmethod
def _check_results(results: List[FLModel]):
empty_clients = []
Expand Down
60 changes: 60 additions & 0 deletions nvflare/app_common/workflows/cyclic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .wf_controller import WFController


class Cyclic(WFController):
def __init__(
self,
*args,
min_clients: int = 2,
num_rounds: int = 5,
start_round: int = 0,
**kwargs,
):
"""The Cyclic Workflow controller to implement the Cyclic Weight Transfer (CWT) algorithm.
Args:
min_clients (int, optional): The minimum number of clients. Defaults to 2.
num_rounds (int, optional): The total number of training rounds. Defaults to 5.
start_round (int, optional): The starting round number. Defaults to 0
"""
super().__init__(*args, **kwargs)

self.min_clients = min_clients
self.num_rounds = num_rounds
self.start_round = start_round
self.current_round = None

def run(self) -> None:
self.info("Start Cyclic.")

model = self.load_model()
model.start_round = self.start_round
model.total_rounds = self.num_rounds

for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
self.info(f"Round {self.current_round} started.")
model.current_round = self.current_round

clients = self.sample_clients(self.min_clients)

for client in clients:
result = self.send_model_and_wait(targets=[client], data=model)[0]
model.params, model.meta = result.params, result.meta

self.save_model(model)

self.info("Finished Cyclic.")
19 changes: 19 additions & 0 deletions nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import gc
import random
from abc import ABC, abstractmethod
from typing import Callable, List, Union

Expand Down Expand Up @@ -165,6 +166,7 @@ def broadcast_model(
# de-reference the internal results before returning
results = self._results
self._results = []

return results
else:
self.broadcast(
Expand Down Expand Up @@ -343,6 +345,23 @@ def save_model(self, model):
else:
self.error("persistor not configured, model will not be saved")

def sample_clients(self, num_clients):
clients = self.engine.get_clients()

if num_clients < len(clients):
random.shuffle(clients)
clients = clients[0:num_clients]
self.info(
f"num_clients ({num_clients}) is less than the number of available clients. Returning a random subset of {num_clients} clients."
)
elif num_clients > len(clients):
self.info(
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
)
self.info(f"Sampled clients: {[client.name for client in clients]}")

return clients

def stop_controller(self, fl_ctx: FLContext):
self.fl_ctx = fl_ctx
self.finalize()
11 changes: 11 additions & 0 deletions nvflare/app_common/workflows/wf_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,14 @@ def save_model(self, model: FLModel):
None
"""
super().save_model(model)

def sample_clients(self, num_clients):
"""Returns a list of available clients.
Args:
min_clients: number of clients to return.
Returns: list of clients.
"""
return super().sample_clients(num_clients)

0 comments on commit 64c1687

Please sign in to comment.