From 0a5436ee3f231788fbff4493c30fc376b277512f Mon Sep 17 00:00:00 2001
From: Sean Yang
Date: Tue, 16 Jul 2024 11:25:25 -0700
Subject: [PATCH] add ModelController docs

---
 .../controllers/model_controller.rst          | 230 ++++++++++++++++++
 docs/programming_guide/fl_model.rst           |   4 +-
 .../workflows_and_controllers.rst             |   6 +-
 .../workflows/base_model_controller.py        |   4 +-
 .../app_common/workflows/model_controller.py  |   6 +-
 5 files changed, 242 insertions(+), 8 deletions(-)
 create mode 100644 docs/programming_guide/controllers/model_controller.rst

diff --git a/docs/programming_guide/controllers/model_controller.rst b/docs/programming_guide/controllers/model_controller.rst
new file mode 100644
index 0000000000..cd8d107296
--- /dev/null
+++ b/docs/programming_guide/controllers/model_controller.rst
@@ -0,0 +1,230 @@
.. _model_controller:

###################
ModelController API
###################

The FLARE :mod:`ModelController` API provides an easy way for users to write and customize FLModel-based controller workflows.

* Highly flexible with a simple API (a ``run()`` routine plus basic communication and utility functions)
* :ref:`fl_model`-based, everything else is pure Python
* Optional support for pre-existing components and FLARE-specific functionalities

Core Concepts
=============

As an example, we can take a look at the popular federated learning workflow "FedAvg", which has the following steps:

#. FL server initializes an initial model
#. For each round (global iteration):

   #. FL server sends the global model to clients
   #. Each FL client starts with this global model and trains on their own data
   #. Each FL client sends back their trained model
   #. FL server aggregates all the models and produces a new global model


To implement this workflow using the ModelController, there are a few essential parts:

* Import and subclass :class:`nvflare.app_common.workflows.model_controller.ModelController`.
* Implement the ``run()`` routine for the workflow logic.
* Use ``send_model()`` / ``send_model_and_wait()`` to send tasks with FLModel to target clients and to receive FLModel results.
* Customize the workflow with predefined utility functions and components, or implement your own logic.


Here is an example of the FedAvg workflow using the :class:`BaseFedAvg` base class:

.. code-block:: python

    # BaseFedAvg subclasses ModelController and defines common functions and variables
    # such as aggregate(), update_model(), self.start_round, self.num_rounds
    class FedAvg(BaseFedAvg):

        # run routine that the user must implement
        def run(self) -> None:
            self.info("Start FedAvg.")

            # load model (by default uses persistor, can provide custom method)
            model = self.load_model()
            model.start_round = self.start_round
            model.total_rounds = self.num_rounds

            # for each round (global iteration)
            for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
                self.info(f"Round {self.current_round} started.")
                model.current_round = self.current_round

                # obtain self.num_clients clients
                clients = self.sample_clients(self.num_clients)

                # send model to target clients with default train task, wait to receive results
                results = self.send_model_and_wait(targets=clients, data=model)

                # use BaseFedAvg aggregate function
                # (defaults to `WeightedAggregationHelper`; self.aggregate_fn can be
                # overwritten with any function of signature Callable[[List[FLModel]], FLModel])
                aggregate_results = self.aggregate(results, aggregate_fn=self.aggregate_fn)

                # update global model with aggregation results
                model = self.update_model(model, aggregate_results)

            # save model (by default uses persistor, can provide custom method)
            self.save_model(model)

            self.info("Finished FedAvg.")

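The default ``aggregate_fn`` uses ``WeightedAggregationHelper``; as noted in the comment above, it can be replaced by any callable of signature ``Callable[[List[FLModel]], FLModel]``. Below is a minimal sketch of a custom, unweighted averaging function; the function name and the assumption that every result carries NumPy arrays in ``params`` are illustrative, not part of the FLARE API:

.. code-block:: python

    from typing import Dict, List

    import numpy as np

    from nvflare.app_common.abstract.fl_model import FLModel


    def simple_mean_aggregate_fn(results: List[FLModel]) -> FLModel:
        # hypothetical unweighted average (the default helper weights each contribution)
        avg_params: Dict = {}
        for key in results[0].params:
            # element-wise mean across the clients' arrays for this parameter
            avg_params[key] = np.mean([r.params[key] for r in results], axis=0)
        return FLModel(params=avg_params)

Such a function could then be passed as ``self.aggregate(results, aggregate_fn=simple_mean_aggregate_fn)`` in place of the default.
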
Below is a comprehensive table overview of the :class:`ModelController` API:


.. list-table:: ModelController API
   :widths: 25 35 50
   :header-rows: 1

   * - API
     - Description
     - API Doc Link
   * - run
     - Run routine for workflow.
     - :func:`run`
   * - send_model_and_wait
     - Send a task with data to targets (blocking) and wait for results.
     - :func:`send_model_and_wait`
   * - send_model
     - Send a task with data to targets (non-blocking) with callback.
     - :func:`send_model`
   * - sample_clients
     - Returns a list of num_clients clients.
     - :func:`sample_clients`
   * - save_model
     - Save model with persistor.
     - :func:`save_model`
   * - load_model
     - Load model from persistor.
     - :func:`load_model`


Communication
=============

send_model_and_wait
-------------------
:func:`send_model_and_wait` is the core communication function that enables users to send tasks to targets and wait for responses.

The ``data`` is an :ref:`fl_model` object, and the ``task_name`` is the task for the target executors to execute (Client API executors by default support "train", "validate", and "submit_model"; however, executors can be written for any arbitrary task name).

``targets`` can be chosen from client names obtained with ``sample_clients()``.

Returns the :ref:`fl_model` responses from the target clients once the task is completed (``min_responses`` have been received, or ``timeout`` time has passed).

send_model
----------
:func:`send_model` is the non-blocking version of :func:`send_model_and_wait` with a user-defined callback that is invoked when responses are received.

A callback with the signature ``Callable[[FLModel], None]`` can be passed in, which will be called when a response is received from each target.

The task remains standing until either ``min_responses`` have been received, or ``timeout`` time has passed.
Since this call is asynchronous, the Controller :func:`get_num_standing_tasks` method can be used to get the number of standing tasks for synchronization purposes.

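As an illustration of this non-blocking pattern, the sketch below sends a task with ``send_model()`` and polls ``get_num_standing_tasks()`` until all responses are in. The controller name, callback name, and polling interval are illustrative assumptions, not FLARE requirements:

.. code-block:: python

    import time

    from nvflare.app_common.abstract.fl_model import FLModel
    from nvflare.app_common.workflows.model_controller import ModelController


    class AsyncExample(ModelController):  # hypothetical controller, for illustration only

        def run(self) -> None:
            model = self.load_model()
            clients = self.sample_clients()
            received = []

            def accept_result(result: FLModel) -> None:
                # called once for each client response
                received.append(result)

            # non-blocking: returns immediately, responses arrive via the callback
            self.send_model(targets=clients, data=model, callback=accept_result)

            # wait until the standing task has no outstanding responses
            # (production code should also honor the controller's abort signal)
            while self.get_num_standing_tasks():
                time.sleep(1)

            self.info(f"Received {len(received)} results.")
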
Saving & Loading
================

persistor
---------
The :func:`save_model` and :func:`load_model` functions utilize the configured persistor set in the ModelController ``persistor_id: str = "persistor"`` argument.

For the JobAPI, the persistor with ``id = "persistor"`` will automatically be configured based on the type of the model sent to the server.

.. code-block:: python

    job.to(Net(), "server")

The persistor can also be configured in ``config_fed_server.json`` in the components section.

custom save & load
------------------
Users can also choose to create their own save and load functions rather than use a persistor.

For example, we can use PyTorch's save and load functions for the model parameters, and save the FLModel metadata with :mod:`FOBS` separately:

.. code-block:: python

    import torch

    from nvflare.fuel.utils import fobs

    def save_model(self, model, filepath=""):
        params = model.params
        # PyTorch save
        torch.save(params, filepath)

        # save FLModel metadata
        model.params = {}
        fobs.dumpf(model, filepath + ".metadata")
        model.params = params

    def load_model(self, filepath=""):
        # PyTorch load
        params = torch.load(filepath)

        # load FLModel metadata
        model = fobs.loadf(filepath + ".metadata")
        model.params = params
        return model


Note: for non-primitive data types such as ``torch.nn.Module`` (used for the initial PyTorch model), we must configure a corresponding FOBS decomposer for serialization and deserialization.
Read more at :github_nvflare_link:`Flare Object Serializer (FOBS) `.

.. code-block:: python

    from nvflare.app_opt.pt.decomposers import TensorDecomposer

    fobs.register(TensorDecomposer)


Additional Functionalities
==========================

In some cases, more advanced FLARE-specific functionalities may be of use.

The :mod:`BaseModelController` class provides access to the engine ``self.engine`` and the FLContext ``self.fl_ctx`` if needed.
Functions such as ``get_component()`` and ``build_component()`` can be used to load or dynamically build components.

Furthermore, the underlying :mod:`Controller` class offers additional communication functions and task-related utilities.
Many of our pre-existing workflows are based on this lower-level Controller API.
For more details, refer to the :ref:`controllers` section.

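For example, a controller could reach other configured components through the engine. The sketch below is illustrative only: it assumes these helpers are reached via ``self.engine``, and the component id ``"my_aggregator"`` is a made-up example rather than a FLARE default:

.. code-block:: python

    from nvflare.app_common.workflows.model_controller import ModelController


    class ComponentAwareController(ModelController):  # hypothetical controller

        def run(self) -> None:
            # look up a component configured in the server app by its id
            aggregator = self.engine.get_component("my_aggregator")
            if aggregator is None:
                self.info("component 'my_aggregator' is not configured")

            # the FLContext is available when lower-level FLARE APIs are needed
            self.info(f"Running in job {self.fl_ctx.get_job_id()}.")
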
Configuration
=============

For the JobAPI, define the controller and send it to the server.

.. code-block:: python

    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

The controller can also be configured in ``config_fed_server.json`` in the workflows section.

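For reference, here is a slightly fuller JobAPI sketch that also adds client-side executors and exports the job. The job name, script name, site names, and output path are placeholders, and the exact ``ScriptRunner`` arguments should be checked against the Job API documentation:

.. code-block:: python

    from nvflare.job_config.api import FedJob
    from nvflare.job_config.script_runner import ScriptRunner

    job = FedJob(name="fedavg_example")  # placeholder job name

    controller = FedAvg(num_clients=2, num_rounds=3)
    job.to(controller, "server")

    # one ScriptRunner (Client API executor) per client site, running a local training script
    for i in range(2):
        job.to(ScriptRunner(script="train.py"), f"site-{i + 1}")  # "train.py" is a placeholder

    job.export_job("/tmp/nvflare/jobs")  # or job.simulator_run(...) to run locally
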
Examples
========

Examples of basic workflows using the ModelController API:

* :github_nvflare_link:`Cyclic `
* :github_nvflare_link:`BaseFedAvg `
* :github_nvflare_link:`FedAvg `

Advanced examples:

* :github_nvflare_link:`Scaffold `
* :github_nvflare_link:`FedOpt `
* :github_nvflare_link:`PTFedAvgEarlyStopping `
* :github_nvflare_link:`Kaplan-Meier `
* :github_nvflare_link:`Logistic Regression Newton Raphson `
* :github_nvflare_link:`FedBPT `

diff --git a/docs/programming_guide/fl_model.rst b/docs/programming_guide/fl_model.rst
index 6b2a9bad07..702af3a4de 100644
--- a/docs/programming_guide/fl_model.rst
+++ b/docs/programming_guide/fl_model.rst
@@ -3,7 +3,7 @@
 FLModel
 =======
 
-We define a standard data structure :mod:`FLModel`
+We define a standard data structure :mod:`FLModel`
 that captures the common attributes needed for exchanging learning results.
 
 This is particularly useful when NVFlare system needs to exchange learning
@@ -14,4 +14,4 @@ information from received FLModel, run local training, and put the results
 in a new FLModel to be sent back.
 
 For a detailed explanation of each attributes, please refer to the API doc:
-:mod:`FLModel`
+:mod:`FLModel`
diff --git a/docs/programming_guide/workflows_and_controllers.rst b/docs/programming_guide/workflows_and_controllers.rst
index 9a75c9901d..9981400019 100644
--- a/docs/programming_guide/workflows_and_controllers.rst
+++ b/docs/programming_guide/workflows_and_controllers.rst
@@ -7,7 +7,10 @@ A workflow has one or more controllers, each implementing a specific coordinatio
 CrossSiteValidation controller implements a strategy to let every client site evaluate every other site's model.
 You can put together a workflow that uses any number of controllers.
 
-We have implemented several server controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side :ref:`controllers `.
+We provide the FLModel-based :ref:`model_controller`, which offers a straightforward way for users to write controllers.
+Additionally, we have the lower-level :ref:`Controller API ` with more FLARE-specific concepts, upon which many of our existing workflows are based.
+
+We have implemented several server controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side controllers.
 In these workflows, FL clients get tasks assigned by the controller, execute the tasks, and submit results back to the server.
 
 In certain cases, if the server cannot be trusted, it should not be involved in communication with sensitive information.
@@ -18,5 +21,6 @@ Please refer to the following sections for more details.
 .. toctree::
    :maxdepth: 3
 
+   controllers/model_controller
    controllers/controllers
    controllers/client_controlled_workflows
diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py
index f0b865ac8b..3310a5014b 100644
--- a/nvflare/app_common/workflows/base_model_controller.py
+++ b/nvflare/app_common/workflows/base_model_controller.py
@@ -351,7 +351,7 @@ def save_model(self, model):
             self.error("persistor not configured, model will not be saved")
 
     def sample_clients(self, num_clients=None):
-        clients = self.engine.get_clients()
+        clients = [client.name for client in self.engine.get_clients()]
 
         if num_clients:
             check_positive_int("num_clients", num_clients)
@@ -366,7 +366,7 @@
                     f"num_clients ({num_clients}) is greater than the number of available clients. Returning all ({len(clients)}) available clients."
                 )
 
-        self.info(f"Sampled clients: {[client.name for client in clients]}")
+        self.info(f"Sampled clients: {clients}")
 
         return clients
diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py
index bd46595d14..125a6ffc97 100644
--- a/nvflare/app_common/workflows/model_controller.py
+++ b/nvflare/app_common/workflows/model_controller.py
@@ -103,7 +103,7 @@ def send_model(
             callback=callback,
         )
 
-    def load_model(self):
+    def load_model(self) -> FLModel:
         """Load initial model from persistor. If persistor is not configured, returns empty FLModel.
 
         Returns:
         """
         return super().load_model()
 
-    def save_model(self, model: FLModel):
+    def save_model(self, model: FLModel) -> None:
         """Saves model with persistor. If persistor is not configured, does not save.
 
         Args:
@@ -122,7 +122,7 @@
-    def sample_clients(self, num_clients=None):
+    def sample_clients(self, num_clients: int = None) -> List[str]:
         """Returns a list of `num_clients` clients.
 
         Args: