Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add controller interface #2451

Merged
merged 3 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions nvflare/apis/wf_controller_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Callable, List, Union


class WFControllerSpec(ABC):
    """Interface for workflow controllers.

    A workflow controller implements `run` as its main routine and may use
    `send_model` to dispatch a task (with model data) to client targets.
    """

    @abstractmethod
    def run(self):
        """Main `run` routine for the controller workflow."""
        raise NotImplementedError

    def send_model(
        self,
        task_name: str,
        # NOTE: `Any` (typing.Any), not the builtin `any` — the builtin is a
        # function, not a type, and is meaningless as an annotation.
        data: Any,
        targets: Union[List[str], None],
        timeout: int,
        wait_time_after_min_received: int,
        blocking: bool,
        callback: Callable,
    ) -> List:
        """Send a task with data to a list of targets.

        Args:
            task_name (str): name of the task.
            data (Any): data to be sent to clients.
            targets (List[str], optional): the list of target client names,
                or None for all clients.
            timeout (int): time to wait for clients to perform task.
            wait_time_after_min_received (int): time to wait after
                minimum number of clients responses has been received.
            blocking (bool): whether to block to wait for task result.
            callback (Callable[[Any], None]): callback when a result is
                received; only used when blocking=False.

        Returns:
            List of results if blocking=True, otherwise None.
        """
        raise NotImplementedError
1 change: 1 addition & 0 deletions nvflare/app_common/app_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class AppConstants(object):
# Task-name constants (class attributes of AppConstants).
TASK_END_RUN = "_end_run_"  # presumably signals end-of-run to executors — TODO confirm against executor code
TASK_TRAIN = "train"  # default training task name
TASK_GET_WEIGHTS = "get_weights"  # task requesting current local weights
TASK_PROP_CALLBACK = "_task_callback_"  # Task prop key holding the per-task result callback

DEFAULT_AGGREGATOR_ID = "aggregator"  # default component id for the aggregator
DEFAULT_PERSISTOR_ID = "persistor"  # default component id for the model persistor
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.apis.wf_controller_spec import WFControllerSpec
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper
Expand All @@ -27,7 +28,7 @@
from .model_controller import ModelController


class BaseFedAvg(ModelController):
class BaseFedAvg(ModelController, WFControllerSpec):
SYangster marked this conversation as resolved.
Show resolved Hide resolved
"""The base controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/workflows/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def run(self) -> None:

clients = self.sample_clients(self._min_clients)

results = self.send_model_and_wait(targets=clients, data=self.model)
results = self.send_model(targets=clients, data=self.model)

aggregate_results = self.aggregate(
results, aggregate_fn=None
Expand Down
127 changes: 76 additions & 51 deletions nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List, Union
from abc import ABC
from typing import Callable, List, Union

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, OperatorMethod, Task, TaskOperatorKey
Expand All @@ -22,6 +22,7 @@
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.wf_controller_spec import WFControllerSpec
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
Expand All @@ -36,7 +37,7 @@


@experimental
class ModelController(Controller, FLComponentWrapper):
class ModelController(Controller, FLComponentWrapper, WFControllerSpec, ABC):
SYangster marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
min_clients: int = 1000,
Expand Down Expand Up @@ -157,33 +158,82 @@ def _build_shareable(self, data: FLModel = None) -> Shareable:

return data_shareable

def send_model(
    self,
    task_name=AppConstants.TASK_TRAIN,
    data: FLModel = None,
    targets: Union[List[Client], List[str], None] = None,
    timeout: int = 0,
    wait_time_after_min_received: int = 10,
    blocking: bool = True,
    callback: Callable[[FLModel], None] = None,
) -> List:
    """Send a task with data to a list of targets.

    The task is scheduled into a task list; clients request tasks and the
    controller dispatches the task to eligible clients.

    Args:
        task_name (str, optional): name of the task. Defaults to "train".
        data (FLModel, optional): FLModel to be sent to clients. If no data
            is given, send `self.model`.
        targets (List[Client] | List[str] | None, optional): target clients
            or client names, or None for all clients. Defaults to None.
        timeout (int, optional): time to wait for clients to perform task.
            Defaults to 0, i.e., never time out.
        wait_time_after_min_received (int, optional): time to wait after the
            minimum number of client responses has been received.
            Defaults to 10.
        blocking (bool): whether to block to wait for task result.
        callback (Callable[[FLModel], None], optional): callback when a
            result is received; only used when blocking=False.

    Returns:
        List[FLModel] of results if blocking=True, otherwise None.

    Raises:
        TypeError: if task_name is not a string.
        ValueError: if timeout or wait_time_after_min_received is negative.
    """
    if not isinstance(task_name, str):
        raise TypeError("task_name must be a string but got {}".format(type(task_name)))
    check_non_negative_int("timeout", timeout)
    check_non_negative_int("wait_time_after_min_received", wait_time_after_min_received)

    # Normalize targets to plain client names so logging and broadcast see a uniform type.
    if targets:
        targets = [client.name if isinstance(client, Client) else client for client in targets]

    task = self._prepare_task(data=data, task_name=task_name, timeout=timeout, callback=callback)

    # Single log point for both branches (was duplicated in each branch).
    self.info(f"Sending task {task_name} to {targets}")

    if blocking:
        self._results = []  # reset results list before collecting
        self.broadcast_and_wait(
            task=task,
            targets=targets,
            min_responses=self._min_clients,
            wait_time_after_min_received=wait_time_after_min_received,
            fl_ctx=self.fl_ctx,
            abort_signal=self.abort_signal,
        )

        if targets is not None:
            if len(self._results) != self._min_clients:
                self.warning(
                    f"Number of results ({len(self._results)}) is different from min_clients ({self._min_clients})."
                )

        # De-reference the internal results list before returning so the next
        # call starts fresh and callers own their result list.
        results = self._results
        self._results = []
        return results
    else:
        # Fire-and-forget: results are routed to `callback` (if given) by
        # _process_result; otherwise they accumulate in self._results.
        self.broadcast(
            task=task,
            targets=targets,
            min_responses=self._min_clients,
            wait_time_after_min_received=wait_time_after_min_received,
            fl_ctx=self.fl_ctx,
        )

def _prepare_task(
    self,
    data: FLModel,
    task_name: str,
    timeout: int,
    callback: Callable,
):
    """Create a Task carrying `data` (or `self.model` when data is None).

    The optional result callback is stored in the task props under
    AppConstants.TASK_PROP_CALLBACK so that `_process_result` can route
    each result to it instead of the internal results list.

    Args:
        data (FLModel): model data to send; passed to `_build_shareable`.
        task_name (str): name of the task.
        timeout (int): task timeout; also placed in the broadcast operator.
        callback (Callable): per-result callback, or None.

    Returns:
        Task: the prepared task, ready to be broadcast.
    """
    data_shareable = self._build_shareable(data)

    # NOTE(review): the first operator entries were hidden in the diff fold;
    # OP_ID/METHOD reconstructed from the standard broadcast pattern — confirm.
    operator = {
        TaskOperatorKey.OP_ID: task_name,
        TaskOperatorKey.METHOD: OperatorMethod.BROADCAST,
        TaskOperatorKey.TIMEOUT: timeout,
    }

    task = Task(
        name=task_name,
        data=data_shareable,
        operator=operator,
        props={AppConstants.TASK_PROP_CALLBACK: callback},
        timeout=timeout,
        before_task_sent_cb=self._prepare_task_data,
        result_received_cb=self._process_result,
    )

    return task

def _prepare_task_data(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
fl_ctx.set_prop(AppConstants.TRAIN_SHAREABLE, client_task.task.data, private=True, sticky=False)
Expand All @@ -245,10 +275,14 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
result_model.meta["current_round"] = self._current_round
result_model.meta["total_rounds"] = self._num_rounds

self._results.append(result_model)
callback = client_task.task.get_prop(AppConstants.TASK_PROP_CALLBACK)
if callback:
callback(result_model)
else:
self._results.append(result_model)

# Cleanup task result
client_task.result = None
# Cleanup task result
client_task.result = None

def process_result_of_unknown_task(
self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
Expand Down Expand Up @@ -278,15 +312,6 @@ def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLCo

self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False)

@abstractmethod
def run(self):
    """Main `run` routine called by the Controller's `control_flow` to execute the workflow.

    Concrete subclasses implement the actual workflow here (e.g. sampling
    clients, sending the model as a task, and aggregating results, as the
    FedAvg and Scaffold workflows do).

    Returns: None.

    """
    raise NotImplementedError

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
self._phase = AppConstants.PHASE_TRAIN
fl_ctx.set_prop(AppConstants.PHASE, self._phase, private=True, sticky=False)
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/workflows/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def run(self) -> None:
global_model = self.model
global_model.meta[AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL] = self._global_ctrl_weights

results = self.send_model_and_wait(targets=clients, data=global_model)
results = self.send_model(targets=clients, data=global_model)

aggregate_results = self.aggregate(results, aggregate_fn=scaffold_aggregate_fn)

Expand Down
Loading