add ModelController docs
SYangster committed Jul 16, 2024
1 parent ed40b45 commit 0a5436e
Showing 5 changed files with 242 additions and 8 deletions.
230 changes: 230 additions & 0 deletions docs/programming_guide/controllers/model_controller.rst
@@ -0,0 +1,230 @@
.. _model_controller:

###################
ModelController API
###################

The FLARE :mod:`ModelController<nvflare.app_common.workflows.model_controller>` API provides an easy way for users to write and customize FLModel-based controller workflows.

* Highly flexible with a simple API (run routine and basic communication and utility functions)
* :ref:`fl_model`-based, everything else is pure Python
* Option to support pre-existing components and FLARE-specific functionalities

Core Concepts
=============

As an example, consider the popular federated learning workflow "FedAvg", which has the following steps:

#. FL server initializes an initial model
#. For each round (global iteration):

   #. FL server sends the global model to clients
   #. Each FL client starts with this global model and trains on their own data
   #. Each FL client sends back their trained model
   #. FL server aggregates all the models and produces a new global model
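
The aggregation step at the end of each round can be sketched in plain Python. This is only an illustration of weighted averaging, not FLARE's implementation: the function name ``weighted_average`` and the use of a per-client sample count as the weight are assumptions made for this example.

```python
from typing import Dict, List, Tuple


def weighted_average(updates: List[Tuple[Dict[str, float], int]]) -> Dict[str, float]:
    """Average parameter dicts, weighting each client by its sample count.

    `updates` is a list of (params, num_samples) pairs; both the name and the
    shape of this input are hypothetical and chosen only for this sketch.
    """
    total = sum(num_samples for _, num_samples in updates)
    if total <= 0:
        raise ValueError("total number of samples must be positive")
    keys = updates[0][0].keys()
    return {
        key: sum(params[key] * num_samples for params, num_samples in updates) / total
        for key in keys
    }


# Two simulated client results: the second client has three times more data.
client_updates = [
    ({"w": 1.0, "b": 0.0}, 100),
    ({"w": 3.0, "b": 2.0}, 300),
]
global_params = weighted_average(client_updates)
```

The client with more samples pulls the averaged parameters toward its own values, which is the intent of weighting in FedAvg-style aggregation.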


To implement this workflow using the ModelController there are a few essential parts:

* Import and subclass the :class:`nvflare.app_common.workflows.model_controller.ModelController`.
* Implement the ``run()`` routine for the workflow logic.
* Use ``send_model()`` / ``send_model_and_wait()`` to send tasks with an FLModel to target clients and receive FLModel results.
* Customize the workflow using predefined utility functions and components, or implement your own logic.


Here is an example of the FedAvg workflow using the :class:`BaseFedAvg<nvflare.app_common.workflows.base_fedavg.BaseFedAvg>` base class:

.. code-block:: python

    from nvflare.app_common.workflows.base_fedavg import BaseFedAvg


    # BaseFedAvg subclasses ModelController and defines common functions and variables
    # such as aggregate(), update_model(), self.start_round, self.num_rounds
    class FedAvg(BaseFedAvg):
        # run routine that user must implement
        def run(self) -> None:
            self.info("Start FedAvg.")

            # load model (by default uses persistor, can provide custom method)
            model = self.load_model()
            model.start_round = self.start_round
            model.total_rounds = self.num_rounds

            # for each round (global iteration)
            for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
                self.info(f"Round {self.current_round} started.")
                model.current_round = self.current_round

                # obtain self.num_clients clients
                clients = self.sample_clients(self.num_clients)

                # send model to target clients with default train task, wait to receive results
                results = self.send_model_and_wait(targets=clients, data=model)

                # use BaseFedAvg aggregate function; the default aggregate_fn uses `WeightedAggregationHelper`.
                # self.aggregate_fn can be overwritten with signature Callable[[List[FLModel]], FLModel]
                aggregate_results = self.aggregate(results, aggregate_fn=self.aggregate_fn)

                # update global model with aggregation results
                model = self.update_model(model, aggregate_results)

                # save model (by default uses persistor, can provide custom method)
                self.save_model(model)

            self.info("Finished FedAvg.")

Below is a comprehensive table overview of the :class:`ModelController<nvflare.app_common.workflows.model_controller.ModelController>` API:


.. list-table:: ModelController API
   :widths: 25 35 50
   :header-rows: 1

   * - API
     - Description
     - API Doc Link
   * - run
     - Run routine for workflow.
     - :func:`run<nvflare.app_common.workflows.model_controller.ModelController.run>`
   * - send_model_and_wait
     - Send a task with data to targets (blocking) and wait for results.
     - :func:`send_model_and_wait<nvflare.app_common.workflows.model_controller.ModelController.send_model_and_wait>`
   * - send_model
     - Send a task with data to targets (non-blocking) with callback.
     - :func:`send_model<nvflare.app_common.workflows.model_controller.ModelController.send_model>`
   * - sample_clients
     - Returns a list of num_clients clients.
     - :func:`sample_clients<nvflare.app_common.workflows.model_controller.ModelController.sample_clients>`
   * - save_model
     - Save model with persistor.
     - :func:`save_model<nvflare.app_common.workflows.model_controller.ModelController.save_model>`
   * - load_model
     - Load model from persistor.
     - :func:`load_model<nvflare.app_common.workflows.model_controller.ModelController.load_model>`
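
To make the ``sample_clients`` row more concrete, here is a rough standalone sketch of the selection behavior. It assumes that all available client names are returned when ``num_clients`` is omitted or exceeds the number of available clients, and that selection is otherwise random; the helper name ``sample_client_names`` is hypothetical and this is not the library's code.

```python
import random
from typing import List, Optional


def sample_client_names(available: List[str], num_clients: Optional[int] = None) -> List[str]:
    """Pick client names in the spirit of sample_clients() (illustrative only)."""
    names = list(available)
    if num_clients is None:
        # No count requested: every available client is a target.
        return names
    if num_clients <= 0:
        raise ValueError("num_clients must be a positive integer")
    if num_clients >= len(names):
        # Asking for more clients than exist falls back to all of them.
        return names
    # Random selection without replacement (an assumption for this sketch).
    return random.sample(names, num_clients)


sites = ["site-1", "site-2", "site-3"]
picked = sample_client_names(sites, 2)
everyone = sample_client_names(sites, 10)
```

The returned values are plain client names, which is the form expected for ``targets`` in the communication functions.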


Communication
=============

send_model_and_wait
-------------------
:func:`send_model_and_wait<nvflare.app_common.workflows.model_controller.ModelController.send_model_and_wait>` is the core communication function which enables users to send tasks to targets, and wait for responses.

The ``data`` is an :ref:`fl_model` object, and the ``task_name`` is the task for the target executors to execute (Client API executors support "train", "validate", and "submit_model" by default; however, executors can be written for any arbitrary task name).

``targets`` can be chosen from client names obtained with ``sample_clients()``.

The function returns the :ref:`fl_model` responses from the target clients once the task is completed (either ``min_responses`` have been received or the ``timeout`` has passed).

send_model
----------
:func:`send_model<nvflare.app_common.workflows.model_controller.ModelController.send_model>` is the non-blocking version of
:func:`send_model_and_wait<nvflare.app_common.workflows.model_controller.ModelController.send_model_and_wait>` with a user-defined callback when receiving responses.

A callback with the signature ``Callable[[FLModel], None]`` can be passed in, which will be called when a response is received from each target.

The task remains standing until either ``min_responses`` have been received or the ``timeout`` has passed.
Since this call is asynchronous, the Controller :func:`get_num_standing_tasks<nvflare.apis.impl.controller.Controller.get_num_standing_tasks>` method can be used to get the number of standing tasks for synchronization purposes.
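
As a rough illustration of the callback pattern, the sketch below collects each response as it arrives. ``ResultStub`` is a hypothetical stand-in for FLModel so that the example runs without FLARE installed, and the responses are simulated by calling the callback directly; in a real workflow the collector would be passed to ``send_model()`` as the callback.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ResultStub:
    """Hypothetical stand-in for FLModel, used only in this sketch."""

    params: Dict[str, Any] = field(default_factory=dict)
    meta: Dict[str, Any] = field(default_factory=dict)


class ResultCollector:
    """Callable matching the Callable[[FLModel], None] callback shape."""

    def __init__(self) -> None:
        self.results: List[ResultStub] = []

    def __call__(self, result: ResultStub) -> None:
        # Invoked once per response; here we simply store it.
        self.results.append(result)


collector = ResultCollector()

# Simulate two client responses arriving one after the other.
for name in ("site-1", "site-2"):
    collector(ResultStub(params={"w": 1.0}, meta={"client": name}))

received_from = [result.meta["client"] for result in collector.results]
```

Keeping the state in an object rather than in globals makes it straightforward to inspect the collected results once no tasks are left standing.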


Saving & Loading
================

persistor
---------
The :func:`save_model<nvflare.app_common.workflows.model_controller.ModelController.save_model>` and :func:`load_model<nvflare.app_common.workflows.model_controller.ModelController.load_model>`
functions utilize the configured persistor set in the ModelController ``persistor_id: str = "persistor"`` argument.

For the JobAPI, the persistor with ``id = "persistor"`` will automatically be configured based on the type of the model sent to the server.

.. code-block:: python

    job.to(Net(), "server")

The persistor can also be configured in ``config_fed_server.json`` in the components section.

custom save & load
------------------
Users can also write their own save and load functions instead of using a persistor.

For example, we can use PyTorch's save and load functions for the model parameters, and save the FLModel metadata with :mod:`FOBS<nvflare.fuel.utils.fobs>` separately:

.. code-block:: python

    import torch
    from nvflare.fuel.utils import fobs


    def save_model(self, model, filepath=""):
        params = model.params
        # PyTorch save
        torch.save(params, filepath)

        # save FLModel metadata
        model.params = {}
        fobs.dumpf(model, filepath + ".metadata")
        model.params = params


    def load_model(self, filepath=""):
        # PyTorch load
        params = torch.load(filepath)

        # load FLModel metadata
        model = fobs.loadf(filepath + ".metadata")
        model.params = params
        return model

Note: for non-primitive data types such as ``torch.nn.Module`` (used for the initial PyTorch model), we must configure a corresponding FOBS decomposer for serialization and deserialization.
Read more at :github_nvflare_link:`Flare Object Serializer (FOBS) <nvflare/fuel/utils/fobs/README.rst>`.

.. code-block:: python

    from nvflare.app_opt.pt.decomposers import TensorDecomposer
    from nvflare.fuel.utils import fobs

    fobs.register(TensorDecomposer)

Additional Functionalities
==========================

In some cases, more advanced FLARE-specific functionalities may be of use.

The :mod:`BaseModelController<nvflare.app_common.workflows.base_model_controller>` class provides access to the engine ``self.engine`` and FLContext ``self.fl_ctx`` if needed.
Functions such as ``get_component()`` and ``build_component()`` can be used to load or dynamically build components.

Furthermore, the underlying :mod:`Controller<nvflare.apis.impl.controller>` class offers additional communication functions and task related utilities.
Many of our pre-existing workflows are based on this lower-level Controller API.
For more details refer to the :ref:`controllers` section.

Configuration
=============

For the JobAPI, define the controller and send it to the server.

.. code-block:: python

    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

The controller can also be configured in ``config_fed_server.json`` in the workflows section.

Examples
========

Examples of basic workflows using the ModelController API:

* :github_nvflare_link:`Cyclic <nvflare/app_common/workflows/cyclic.py>`
* :github_nvflare_link:`BaseFedAvg <nvflare/app_common/workflows/base_fedavg.py>`
* :github_nvflare_link:`FedAvg <nvflare/app_common/workflows/fedavg.py>`

Advanced examples:

* :github_nvflare_link:`Scaffold <nvflare/app_common/workflows/scaffold.py>`
* :github_nvflare_link:`FedOpt <nvflare/app_opt/pt/fedopt_ctl.py>`
* :github_nvflare_link:`PTFedAvgEarlyStopping <nvflare/app_opt/pt/fedavg_early_stopping.py>`
* :github_nvflare_link:`Kaplan-Meier <examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py>`
* :github_nvflare_link:`Logistic Regression Newton Raphson <examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_workflow.py>`
* :github_nvflare_link:`FedBPT <research/fed-bpt/src/global_es.py>`
4 changes: 2 additions & 2 deletions docs/programming_guide/fl_model.rst
@@ -3,7 +3,7 @@
FLModel
=======

We define a standard data structure :mod:`FLModel<nvflare.app_common.abstract.FLModel>`
We define a standard data structure :mod:`FLModel<nvflare.app_common.abstract.fl_model>`
that captures the common attributes needed for exchanging learning results.

This is particularly useful when NVFlare system needs to exchange learning
@@ -14,4 +14,4 @@ information from received FLModel, run local training, and put the results
in a new FLModel to be sent back.

For a detailed explanation of each attributes, please refer to the API doc:
:mod:`FLModel<nvflare.app_common.abstract.FLModel>`
:mod:`FLModel<nvflare.app_common.abstract.fl_model>`
6 changes: 5 additions & 1 deletion docs/programming_guide/workflows_and_controllers.rst
@@ -7,7 +7,10 @@ A workflow has one or more controllers, each implementing a specific coordinatio
CrossSiteValidation controller implements a strategy to let every client site evaluate every other site's model. You can put together
a workflow that uses any number of controllers.

We have implemented several server controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side :ref:`controllers <controllers>`.
We provide the FLModel-based :ref:`model_controller`, which offers a straightforward way for users to write controllers.
Additionally, we have the lower-level :ref:`Controller API <controllers>` with more FLARE-specific concepts, which many of our existing workflows are based upon.

We have implemented several server controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side controllers.
In these workflows, FL clients get tasks assigned by the controller, execute the tasks, and submit results back to the server.

In certain cases, if the server cannot be trusted, it should not be involved in communication with sensitive information.
@@ -18,5 +21,6 @@ Please refer to the following sections for more details.
.. toctree::
:maxdepth: 3

controllers/model_controller
controllers/controllers
controllers/client_controlled_workflows
4 changes: 2 additions & 2 deletions nvflare/app_common/workflows/base_model_controller.py
@@ -351,7 +351,7 @@ def save_model(self, model):
self.error("persistor not configured, model will not be saved")

def sample_clients(self, num_clients=None):
clients = self.engine.get_clients()
clients = [client.name for client in self.engine.get_clients()]

if num_clients:
check_positive_int("num_clients", num_clients)
@@ -366,7 +366,7 @@ def sample_clients(self, num_clients=None):
f"num_clients ({num_clients}) is greater than the number of available clients. Returning all ({len(clients)}) available clients."
)

self.info(f"Sampled clients: {[client.name for client in clients]}")
self.info(f"Sampled clients: {clients}")

return clients

6 changes: 3 additions & 3 deletions nvflare/app_common/workflows/model_controller.py
@@ -103,15 +103,15 @@ def send_model(
callback=callback,
)

def load_model(self):
def load_model(self) -> FLModel:
"""Load initial model from persistor. If persistor is not configured, returns empty FLModel.
Returns:
FLModel
"""
return super().load_model()

def save_model(self, model: FLModel):
def save_model(self, model: FLModel) -> None:
"""Saves model with persistor. If persistor is not configured, does not save.
Args:
@@ -122,7 +122,7 @@ def save_model(self, model: FLModel):
"""
super().save_model(model)

def sample_clients(self, num_clients=None):
def sample_clients(self, num_clients: int = None) -> List[str]:
"""Returns a list of `num_clients` clients.
Args: