From 80bc167bf95ef1bd19f98a2f438fa539213a2f3b Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Thu, 5 Sep 2024 21:37:54 -0700 Subject: [PATCH] change getting started examples to use BaseFedJob (#2919) --- .../nvflare_lightning_getting_started.ipynb | 96 ++---------- .../pt/nvflare_pt_getting_started.ipynb | 141 ++---------------- .../tf/nvflare_tf_getting_started.ipynb | 98 ++---------- nvflare/app_opt/pt/job_config/base_fed_job.py | 10 +- nvflare/app_opt/pt/job_config/fed_avg.py | 3 + nvflare/app_opt/tf/job_config/base_fed_job.py | 72 +++++++++ nvflare/app_opt/tf/job_config/fed_avg.py | 37 +---- web/src/components/code.astro | 60 +++++--- 8 files changed, 173 insertions(+), 344 deletions(-) create mode 100644 nvflare/app_opt/tf/job_config/base_fed_job.py diff --git a/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb b/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb index f07e1abcf2..6f26be3100 100644 --- a/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb +++ b/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb @@ -319,7 +319,10 @@ "metadata": {}, "source": [ "#### 2. Define a FedJob\n", - "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine." + "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", + "\n", + "Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n", + "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." 
] }, { @@ -329,11 +332,16 @@ "metadata": {}, "outputs": [], "source": [ - "from nvflare import FedJob\n", - "from nvflare.job_config.script_runner import ScriptRunner\n", + "from src.lit_net import LitNet\n", + "\n", "from nvflare.app_common.workflows.fedavg import FedAvg\n", + "from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.job_config.script_runner import ScriptRunner\n", "\n", - "job = FedJob(name=\"cifar10_lightning_fedavg\")" + "job = BaseFedJob(\n", + " name=\"cifar10_lightning_fedavg\",\n", + " initial_model=LitNet(),\n", + ")" ] }, { @@ -361,49 +369,6 @@ "job.to(controller, \"server\")" ] }, - { - "cell_type": "markdown", - "id": "7a63ce0c-ad3e-4434-b2a8-c8f2a4c2e7a5", - "metadata": {}, - "source": [ - "#### 4. Create Global Model\n", - "Now, we create the initial global model and send to server." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e2c514c-7758-4d30-bb5c-ae3c63be50aa", - "metadata": {}, - "outputs": [], - "source": [ - "from src.lit_net import LitNet\n", - "from nvflare.app_opt.pt.job_config.model import PTModel\n", - "\n", - "job.to(PTModel(LitNet()), \"server\")" - ] - }, - { - "cell_type": "markdown", - "id": "72eefb39", - "metadata": {}, - "source": [ - "#### 5. Add ModelSelector\n", - "Add IntimeModelSelector for global best model selection." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "091beb78", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n", - "\n", - "job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")" - ] - }, { "cell_type": "markdown", "id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070", @@ -412,43 +377,12 @@ "That completes the components that need to be defined on the server." 
] }, - { - "cell_type": "markdown", - "id": "32686782", - "metadata": {}, - "source": [ - "#### OPTIONAL: Define a FedAvgJob\n", - "\n", - "Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n", - "The `FedAvgJob` automatically configures the `FedAvg`` server controller, along the other components for model persistence and model selection." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02fde3ae", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n", - "\n", - "n_clients = 2\n", - "\n", - "# Create FedAvg Job with initial model\n", - "job = FedAvgJob(\n", - " name=\"cifar10_lightning_fedavg\",\n", - " num_rounds=2,\n", - " n_clients=n_clients,\n", - " initial_model=LitNet(),\n", - ")" - ] - }, { "cell_type": "markdown", "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0", "metadata": {}, "source": [ - "#### 6. Add client ScriptRunners\n", + "#### 4. Add clients\n", "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n", "\n", "Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity." @@ -475,7 +409,7 @@ "source": [ "That's it!\n", "\n", - "#### 7. Optionally export the job\n", + "#### 5. Optionally export the job\n", "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html). " ] }, @@ -494,7 +428,7 @@ "id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63", "metadata": {}, "source": [ - "#### 8. Run FL Simulation\n", + "#### 6. Run FL Simulation\n", "Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. 
We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`." ] }, diff --git a/examples/getting_started/pt/nvflare_pt_getting_started.ipynb b/examples/getting_started/pt/nvflare_pt_getting_started.ipynb index 8bcfc67a7e..79e1b99259 100644 --- a/examples/getting_started/pt/nvflare_pt_getting_started.ipynb +++ b/examples/getting_started/pt/nvflare_pt_getting_started.ipynb @@ -261,7 +261,10 @@ "metadata": {}, "source": [ "#### 2. Define a FedJob\n", - "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine." + "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", + "\n", + "Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n", + "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." ] }, { @@ -271,11 +274,16 @@ "metadata": {}, "outputs": [], "source": [ - "from nvflare import FedJob\n", - "from nvflare.job_config.script_runner import ScriptRunner\n", + "from src.net import Net\n", + "\n", "from nvflare.app_common.workflows.fedavg import FedAvg\n", + "from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.job_config.script_runner import ScriptRunner\n", "\n", - "job = FedJob(name=\"cifar10_pt_fedavg\")" + "job = BaseFedJob(\n", + " name=\"cifar10_pt_fedavg\",\n", + " initial_model=Net(),\n", + ")" ] }, { @@ -303,71 +311,6 @@ "job.to(controller, \"server\")" ] }, - { - "cell_type": "markdown", - "id": "7a63ce0c-ad3e-4434-b2a8-c8f2a4c2e7a5", - "metadata": {}, - "source": [ - "#### 4. Create Global Model\n", - "Now, we create the initial global model and send to server." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e2c514c-7758-4d30-bb5c-ae3c63be50aa", - "metadata": {}, - "outputs": [], - "source": [ - "from src.net import Net\n", - "from nvflare.app_opt.pt.job_config.model import PTModel\n", - "\n", - "job.to(PTModel(Net()), \"server\")" - ] - }, - { - "cell_type": "markdown", - "id": "eccae908", - "metadata": {}, - "source": [ - "#### 5. Add ModelSelector\n", - "Add IntimeModelSelector for global best model selection." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d52dd194", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n", - "\n", - "job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")" - ] - }, - { - "cell_type": "markdown", - "id": "3fbca796-676c-416a-a500-de3429e4a39f", - "metadata": {}, - "source": [ - "#### 6. Add TB Receiver\n", - "Add TBAnalyticsReceiver for tensorboard records." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c3426c2a-d84a-4ed7-8c23-9e1ae389d215", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver\n", - "\n", - "component = TBAnalyticsReceiver(events=[\"fed.analytix_log_stats\"])\n", - "job.to(id=\"receiver\", obj=component, target=\"server\")" - ] - }, { "cell_type": "markdown", "id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070", @@ -376,66 +319,12 @@ "That completes the components that need to be defined on the server." ] }, - { - "cell_type": "markdown", - "id": "6059b304", - "metadata": {}, - "source": [ - "#### 7. 
Add TB Event\n", - "Add tensorboard logging to clients" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "51d8bcda", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent\n", - "\n", - "for i in range(n_clients):\n", - " component = ConvertToFedEvent(events_to_convert=[\"analytix_log_stats\"], fed_event_prefix=\"fed.\")\n", - " job.to(id=\"event_to_fed\", obj=component, target=f\"site-{i+1}\")" - ] - }, - { - "cell_type": "markdown", - "id": "7c95e3f6", - "metadata": {}, - "source": [ - "#### OPTIONAL: Define a FedAvgJob\n", - "\n", - "Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n", - "The `FedAvgJob` automatically configures the `FedAvg`` server controller, along the other components for model persistence, model selection, and TensorBoard streaming.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c4dfc3e7", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n", - "\n", - "n_clients = 2\n", - "\n", - "# Create FedAvg Job with initial model\n", - "job = FedAvgJob(\n", - " name=\"cifar10_pt_fedavg\",\n", - " num_rounds=2,\n", - " n_clients=n_clients,\n", - " initial_model=Net(),\n", - ")" - ] - }, { "cell_type": "markdown", "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0", "metadata": {}, "source": [ - "#### 8. Add client ScriptRunners\n", + "#### 4. Add clients\n", "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n", "\n", "Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity." @@ -462,7 +351,7 @@ "source": [ "That's it!\n", "\n", - "#### 9. Optionally export the job\n", + "#### 5. 
Optionally export the job\n", "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html)." ] }, @@ -481,7 +370,7 @@ "id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63", "metadata": {}, "source": [ - "#### 10. Run FL Simulation\n", + "#### 6. Run FL Simulation\n", "Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`." ] }, diff --git a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb index 655287f685..61afb4f870 100644 --- a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb +++ b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb @@ -251,7 +251,10 @@ "metadata": {}, "source": [ "#### 2. Define a FedJob\n", - "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine." + "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", + "\n", + "Here we use a TensorFlow `BaseFedJob`, where we can define the job name and the initial global model.\n", + "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." 
] }, { @@ -261,11 +264,16 @@ "metadata": {}, "outputs": [], "source": [ - "from nvflare import FedJob\n", - "from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n", + "from src.tf_net import TFNet\n", + "\n", "from nvflare.app_common.workflows.fedavg import FedAvg\n", + "from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n", "\n", - "job = FedJob(name=\"cifar10_tf_fedavg\")" + "job = BaseFedJob(\n", + " name=\"cifar10_tf_fedavg\",\n", + " initial_model=TFNet(),\n", + ")" ] }, { @@ -293,49 +301,6 @@ "job.to(controller, \"server\")" ] }, - { - "cell_type": "markdown", - "id": "7a63ce0c-ad3e-4434-b2a8-c8f2a4c2e7a5", - "metadata": {}, - "source": [ - "#### 4. Create Global Model\n", - "Now, we create the initial global model and send to server." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e2c514c-7758-4d30-bb5c-ae3c63be50aa", - "metadata": {}, - "outputs": [], - "source": [ - "from src.tf_net import TFNet\n", - "from nvflare.app_opt.tf.job_config.model import TFModel\n", - "\n", - "job.to(TFModel(TFNet()), \"server\")" - ] - }, - { - "cell_type": "markdown", - "id": "25c6eada", - "metadata": {}, - "source": [ - "#### 5. Add ModelSelector\n", - "Add IntimeModelSelector for global best model selection." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ae73e50", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n", - "\n", - "job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")" - ] - }, { "cell_type": "markdown", "id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070", @@ -344,43 +309,12 @@ "That completes the components that need to be defined on the server." 
] }, - { - "cell_type": "markdown", - "id": "b89f7442", - "metadata": {}, - "source": [ - "#### OPTIONAL: Define a FedAvgJob\n", - "\n", - "Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n", - "The `FedAvgJob` automatically configures the `FedAvg` server controller, along the other components for model persistence and model selection." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0fd6a908", - "metadata": {}, - "outputs": [], - "source": [ - "from nvflare.app_opt.tf.job_config.fed_avg import FedAvgJob\n", - "\n", - "n_clients = 2\n", - "\n", - "# Create FedAvg Job with initial model\n", - "job = FedAvgJob(\n", - " name=\"cifar10_tf_fedavg\",\n", - " num_rounds=2,\n", - " n_clients=n_clients,\n", - " initial_model=TFNet(input_shape=(None, 32, 32, 3)),\n", - ")" - ] - }, { "cell_type": "markdown", "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0", "metadata": {}, "source": [ - "#### 6. Add client ScriptRunners\n", + "#### 4. Add clients\n", "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n", "\n", "Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity." @@ -408,7 +342,7 @@ "source": [ "That's it!\n", "\n", - "#### 7. Optionally export the job\n", + "#### 5. Optionally export the job\n", "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html). " ] }, @@ -427,7 +361,7 @@ "id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63", "metadata": {}, "source": [ - "#### 8. Run FL Simulation\n", + "#### 6. Run FL Simulation\n", "Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. 
We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`." ] }, @@ -466,7 +400,7 @@ "id": "387662f4-7d05-4840-bcc7-a2523e03c2c2", "metadata": {}, "source": [ - "#### 9. Next steps\n", + "#### 7. Next steps\n", "\n", "Continue with the steps described in the [README.md](README.md) to run more experiments with a more complex model and more advanced FL algorithms. " ] diff --git a/nvflare/app_opt/pt/job_config/base_fed_job.py b/nvflare/app_opt/pt/job_config/base_fed_job.py index 5499f31c51..ce83a886a8 100644 --- a/nvflare/app_opt/pt/job_config/base_fed_job.py +++ b/nvflare/app_opt/pt/job_config/base_fed_job.py @@ -27,7 +27,7 @@ class BaseFedJob(FedJob): def __init__( self, - initial_model: nn.Module, + initial_model: nn.Module = None, name: str = "fed_job", min_clients: int = 1, mandatory_clients: Optional[List[str]] = None, @@ -40,8 +40,8 @@ def __init__( User must add executors. Args: - initial_model (nn.Module): initial PyTorch Model - name (name, optional): name of the job. Defaults to "fed_job" + initial_model (nn.Module, optional): initial PyTorch Model. Defaults to None. + name (str, optional): name of the job. Defaults to "fed_job". min_clients (int, optional): the minimum number of clients for the job. Defaults to 1. mandatory_clients (List[str], optional): mandatory clients to run the job. Default None. key_metric (str, optional): Metric used to determine if the model is globally best. 
@@ -51,6 +51,7 @@ def __init__( super().__init__(name, min_clients, mandatory_clients) self.key_metric = key_metric self.initial_model = initial_model + self.comp_ids = {} component = ValidationJsonGenerator() self.to_server(id="json_generator", obj=component) @@ -63,7 +64,8 @@ def __init__( component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) self.to_server(id="receiver", obj=component) - self.comp_ids = self.to_server(PTModel(initial_model)) + if initial_model is not None: + self.comp_ids.update(self.to_server(PTModel(initial_model))) def set_up_client(self, target: str): component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") diff --git a/nvflare/app_opt/pt/job_config/fed_avg.py b/nvflare/app_opt/pt/job_config/fed_avg.py index 9efc8e8cb3..58f10e27a5 100644 --- a/nvflare/app_opt/pt/job_config/fed_avg.py +++ b/nvflare/app_opt/pt/job_config/fed_avg.py @@ -47,6 +47,9 @@ def __init__( if metrics are a `dict`, `key_metric` can select the metric used for global model selection. Defaults to "accuracy". """ + if not isinstance(initial_model, nn.Module): + raise ValueError(f"Expected initial model to be nn.Module, but got type {type(initial_model)}.") + super().__init__(initial_model, name, min_clients, mandatory_clients, key_metric) controller = FedAvg( diff --git a/nvflare/app_opt/tf/job_config/base_fed_job.py b/nvflare/app_opt/tf/job_config/base_fed_job.py new file mode 100644 index 0000000000..173b39af3e --- /dev/null +++ b/nvflare/app_opt/tf/job_config/base_fed_job.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import tensorflow as tf + +from nvflare import FedJob +from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent +from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector +from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator +from nvflare.app_opt.tf.job_config.model import TFModel +from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver + + +class BaseFedJob(FedJob): + def __init__( + self, + initial_model: tf.keras.Model = None, + name: str = "fed_job", + min_clients: int = 1, + mandatory_clients: Optional[List[str]] = None, + key_metric: str = "accuracy", + ): + """TensorFlow BaseFedJob. + + Configures server side widgets and, when an initial model is given, a persistor for it. + + User must add a controller and executors. + + Args: + initial_model (tf.keras.Model, optional): initial TensorFlow Model. Defaults to None. + name (str, optional): name of the job. Defaults to "fed_job". + min_clients (int, optional): the minimum number of clients for the job. Defaults to 1. + mandatory_clients (List[str], optional): mandatory clients to run the job. Defaults to None. + key_metric (str, optional): Metric used to determine if the model is globally best. + if metrics are a `dict`, `key_metric` can select the metric used for global model selection. + Defaults to "accuracy". 
+ """ + super().__init__(name, min_clients, mandatory_clients) + self.key_metric = key_metric + self.initial_model = initial_model + self.comp_ids = {} + + component = ValidationJsonGenerator() + self.to_server(id="json_generator", obj=component) + + if self.key_metric: + component = IntimeModelSelector(key_metric=self.key_metric) + self.to_server(id="model_selector", obj=component) + + # TODO: make different tracking receivers configurable + component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) + self.to_server(id="receiver", obj=component) + + if initial_model: + self.comp_ids["persistor_id"] = self.to_server(TFModel(initial_model)) + + def set_up_client(self, target: str): + component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") + self.to(id="event_to_fed", obj=component, target=target) diff --git a/nvflare/app_opt/tf/job_config/fed_avg.py b/nvflare/app_opt/tf/job_config/fed_avg.py index 2c2dadf015..e25d87f574 100644 --- a/nvflare/app_opt/tf/job_config/fed_avg.py +++ b/nvflare/app_opt/tf/job_config/fed_avg.py @@ -15,16 +15,11 @@ import tensorflow as tf -from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent -from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector -from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator from nvflare.app_common.workflows.fedavg import FedAvg -from nvflare.app_opt.tf.job_config.model import TFModel -from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver -from nvflare.job_config.api import FedJob +from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob -class FedAvgJob(FedJob): +class FedAvgJob(BaseFedJob): def __init__( self, initial_model: tf.keras.Model, @@ -42,7 +37,7 @@ def __init__( User must add executors. 
Args: - initial_model (nn.Module): initial TensorFlow Model + initial_model (tf.keras.Model): initial TensorFlow Model n_clients (int): number of clients for this job num_rounds (int): number of rounds for FedAvg name (name, optional): name of the job. Defaults to "fed_job" @@ -52,32 +47,14 @@ def __init__( if metrics are a `dict`, `key_metric` can select the metric used for global model selection. Defaults to "accuracy". """ - super().__init__(name, min_clients, mandatory_clients) - self.key_metric = key_metric - self.initial_model = initial_model - self.num_rounds = num_rounds - self.n_clients = n_clients + if not isinstance(initial_model, tf.keras.Model): + raise ValueError(f"Expected initial model to be tf.keras.Model, but got type {type(initial_model)}.") - component = ValidationJsonGenerator() - self.to_server(id="json_generator", obj=component) - - if self.key_metric: - component = IntimeModelSelector(key_metric=self.key_metric) - self.to_server(id="model_selector", obj=component) - - # TODO: make different tracking receivers configurable - component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) - self.to_server(id="receiver", obj=component) - - persistor_id = self.to_server(TFModel(initial_model)) + super().__init__(initial_model, name, min_clients, mandatory_clients, key_metric) controller = FedAvg( num_clients=n_clients, num_rounds=num_rounds, - persistor_id=persistor_id, + persistor_id=self.comp_ids["persistor_id"], ) self.to_server(controller) - - def set_up_client(self, target: str): - component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.") - self.to(id="event_to_fed", obj=component, target=target) diff --git a/web/src/components/code.astro b/web/src/components/code.astro index 6ec36d85a0..98ad476d88 100644 --- a/web/src/components/code.astro +++ b/web/src/components/code.astro @@ -203,7 +203,8 @@ class FedAvg(BaseFedAvg): const jobCode_pt = ` from cifar10_pt_fl import Net -from 
nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob +from nvflare.app_common.workflows.fedavg import FedAvg +from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob from nvflare.job_config.script_runner import ScriptRunner if __name__ == "__main__": @@ -211,13 +212,18 @@ if __name__ == "__main__": num_rounds = 2 train_script = "cifar10_pt_fl.py" - # Create FedAvg Job with initial model - job = FedAvgJob( - name="cifar10_pt_fedavg", + # Create BaseFedJob with initial model + job = BaseFedJob( + name="cifar10_pt_fedavg", + initial_model=Net(), + ) + + # Define the controller and send to server + controller = FedAvg( + num_clients=n_clients, num_rounds=num_rounds, - n_clients=n_clients, - initial_model=Net(), ) + job.to_server(controller) # Add clients for i in range(n_clients): @@ -416,7 +422,8 @@ class FedAvg(BaseFedAvg): const jobCode_lt = ` from cifar10_lightning_fl import LitNet -from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob +from nvflare.app_common.workflows.fedavg import FedAvg +from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob from nvflare.job_config.script_runner import ScriptRunner if __name__ == "__main__": @@ -424,13 +431,18 @@ if __name__ == "__main__": num_rounds = 2 train_script = "cifar10_lightning_fl.py" - # Create FedAvg Job with initial model - job = FedAvgJob( - name="cifar10_lightning_fedavg", + # Create BaseFedJob with initial model + job = BaseFedJob( + name="cifar10_lightning_fedavg", + initial_model=LitNet(), + ) + + # Define the controller and send to server + controller = FedAvg( + num_clients=n_clients, num_rounds=num_rounds, - n_clients=n_clients, - initial_model=LitNet(), ) + job.to_server(controller) # Add clients for i in range(n_clients): @@ -572,7 +584,8 @@ class FedAvg(BaseFedAvg): const jobCode_tf = ` from cifar10_tf_fl import TFNet -from nvflare.app_opt.tf.job_config.fed_avg import FedAvgJob +from nvflare.app_common.workflows.fedavg import FedAvg +from 
nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob from nvflare.job_config.script_runner import FrameworkType, ScriptRunner if __name__ == "__main__": @@ -580,13 +593,18 @@ if __name__ == "__main__": num_rounds = 2 train_script = "cifar10_tf_fl.py" - # Create FedAvg Job with initial model - job = FedAvgJob( - name="cifar10_tf_fedavg", + # Create BaseFedJob with initial model + job = BaseFedJob( + name="cifar10_tf_fedavg", + initial_model=TFNet(input_shape=(None, 32, 32, 3)), + ) + + # Define the controller and send to server + controller = FedAvg( + num_clients=n_clients, num_rounds=num_rounds, - n_clients=n_clients, - initial_model=TFNet(input_shape=(None, 32, 32, 3)), ) + job.to_server(controller) # Add clients for i in range(n_clients): @@ -645,7 +663,7 @@ const frameworks = [ framework: "pytorch", title: "Job Code (fedavg_cifar10_pt_job.py)", description: - "Lastly we construct the job with our 'cifar10_pt_fl.py' client script and FedAvgJob. The FedAvgJob configures our 'FedAvg' server controller, along with other useful components for model persistance, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", + "Lastly we construct the job with our 'cifar10_pt_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", code: jobCode_pt, }, { @@ -699,7 +717,7 @@ const frameworks = [ framework: "lightning", title: "Job Code (fedavg_cifar10_lightning_job.py)", description: - "Lastly we construct the job with our 'cifar10_lightning_fl.py' client script and FedAvgJob. The FedAvgJob configures our 'FedAvg' server controller, along with other useful components for model persistance, model selection, and TensorBoard streaming. 
We then run the job with the FL simulator.", + "Lastly we construct the job with our 'cifar10_lightning_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", code: jobCode_lt, }, { @@ -753,7 +771,7 @@ const frameworks = [ framework: "tensorflow", title: "Job Code (fedavg_cifar10_tf_job.py)", description: - "Lastly we construct the job with our 'cifar10_tf_fl.py' client script and FedAvgJob. The FedAvgJob configures our 'FedAvg' server controller, along with other useful components for model persistance, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", + "Lastly we construct the job with our 'cifar10_tf_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", code: jobCode_tf, }, {