change getting started examples to use BaseFedJob (#2919)
SYangster authored Sep 6, 2024
1 parent 9b143b8 commit 80bc167
Showing 8 changed files with 173 additions and 344 deletions.
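In short, this change collapses the manual server-side setup in the getting-started notebooks into a single `BaseFedJob`. Below is a condensed before/after sketch pieced together from the removed and added cells in the diffs that follow; import paths, arguments, and the `Net` model all come from the diff itself (the Lightning notebook makes the same switch with `LitNet`):

from nvflare import FedJob
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from src.net import Net  # model defined in the example's src/ folder

# Before: each server-side component is wired up by hand on a plain FedJob.
job = FedJob(name="cifar10_pt_fedavg")
job.to(PTModel(Net()), "server")                              # model persistence
job.to(IntimeModelSelector(key_metric="accuracy"), "server")  # best-model selection
job.to(TBAnalyticsReceiver(events=["fed.analytix_log_stats"]), "server")  # TensorBoard streaming

# After: BaseFedJob configures the same components internally.
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob

job = BaseFedJob(
    name="cifar10_pt_fedavg",
    initial_model=Net(),
)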
96 changes: 15 additions & 81 deletions examples/getting_started/pt/nvflare_lightning_getting_started.ipynb
@@ -319,7 +319,10 @@
"metadata": {},
"source": [
"#### 2. Define a FedJob\n",
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine."
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n",
"\n",
"Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n",
"The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
]
},
{
@@ -329,11 +332,16 @@
"metadata": {},
"outputs": [],
"source": [
"from nvflare import FedJob\n",
"from nvflare.job_config.script_runner import ScriptRunner\n",
"from src.lit_net import LitNet\n",
"\n",
"from nvflare.app_common.workflows.fedavg import FedAvg\n",
"from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n",
"from nvflare.job_config.script_runner import ScriptRunner\n",
"\n",
"job = FedJob(name=\"cifar10_lightning_fedavg\")"
"job = BaseFedJob(\n",
" name=\"cifar10_lightning_fedavg\",\n",
" initial_model=LitNet(),\n",
")"
]
},
{
@@ -361,49 +369,6 @@
"job.to(controller, \"server\")"
]
},
{
"cell_type": "markdown",
"id": "7a63ce0c-ad3e-4434-b2a8-c8f2a4c2e7a5",
"metadata": {},
"source": [
"#### 4. Create Global Model\n",
"Now, we create the initial global model and send to server."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e2c514c-7758-4d30-bb5c-ae3c63be50aa",
"metadata": {},
"outputs": [],
"source": [
"from src.lit_net import LitNet\n",
"from nvflare.app_opt.pt.job_config.model import PTModel\n",
"\n",
"job.to(PTModel(LitNet()), \"server\")"
]
},
{
"cell_type": "markdown",
"id": "72eefb39",
"metadata": {},
"source": [
"#### 5. Add ModelSelector\n",
"Add IntimeModelSelector for global best model selection."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "091beb78",
"metadata": {},
"outputs": [],
"source": [
"from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n",
"\n",
"job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")"
]
},
{
"cell_type": "markdown",
"id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070",
@@ -412,43 +377,12 @@
"That completes the components that need to be defined on the server."
]
},
{
"cell_type": "markdown",
"id": "32686782",
"metadata": {},
"source": [
"#### OPTIONAL: Define a FedAvgJob\n",
"\n",
"Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n",
"The `FedAvgJob` automatically configures the `FedAvg`` server controller, along the other components for model persistence and model selection."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02fde3ae",
"metadata": {},
"outputs": [],
"source": [
"from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n",
"\n",
"n_clients = 2\n",
"\n",
"# Create FedAvg Job with initial model\n",
"job = FedAvgJob(\n",
" name=\"cifar10_lightning_fedavg\",\n",
" num_rounds=2,\n",
" n_clients=n_clients,\n",
" initial_model=LitNet(),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0",
"metadata": {},
"source": [
"#### 6. Add client ScriptRunners\n",
"#### 4. Add clients\n",
"Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n",
"\n",
"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity."
@@ -475,7 +409,7 @@
"source": [
"That's it!\n",
"\n",
"#### 7. Optionally export the job\n",
"#### 5. Optionally export the job\n",
"Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html). "
]
},
@@ -494,7 +428,7 @@
"id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63",
"metadata": {},
"source": [
"#### 8. Run FL Simulation\n",
"#### 6. Run FL Simulation\n",
"Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`."
]
},
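For orientation, the Lightning notebook's job definition after this change reads roughly as in the sketch below. The `BaseFedJob` and controller wiring follow the cells shown above; the client and round counts mirror the values in the removed FedAvgJob cell, while the `FedAvg` constructor arguments and the training-script path are not visible in the hunks and are assumptions:

from src.lit_net import LitNet

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner

n_clients = 2   # mirrors the removed FedAvgJob cell
num_rounds = 2  # mirrors the removed FedAvgJob cell

job = BaseFedJob(
    name="cifar10_lightning_fedavg",
    initial_model=LitNet(),
)

# Server: the FedAvg controller (step 3 in the notebook).
controller = FedAvg(num_clients=n_clients, num_rounds=num_rounds)
job.to(controller, "server")

# Clients: one ScriptRunner per simulated site (step 4).
for i in range(n_clients):
    runner = ScriptRunner(script="src/cifar10_lightning_fl.py")  # script path is a placeholder
    job.to(runner, f"site-{i + 1}")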
141 changes: 15 additions & 126 deletions examples/getting_started/pt/nvflare_pt_getting_started.ipynb
@@ -261,7 +261,10 @@
"metadata": {},
"source": [
"#### 2. Define a FedJob\n",
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine."
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n",
"\n",
"Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n",
"The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
]
},
{
@@ -271,11 +274,16 @@
"metadata": {},
"outputs": [],
"source": [
"from nvflare import FedJob\n",
"from nvflare.job_config.script_runner import ScriptRunner\n",
"from src.net import Net\n",
"\n",
"from nvflare.app_common.workflows.fedavg import FedAvg\n",
"from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n",
"from nvflare.job_config.script_runner import ScriptRunner\n",
"\n",
"job = FedJob(name=\"cifar10_pt_fedavg\")"
"job = BaseFedJob(\n",
" name=\"cifar10_pt_fedavg\",\n",
" initial_model=Net(),\n",
")"
]
},
{
@@ -303,71 +311,6 @@
"job.to(controller, \"server\")"
]
},
{
"cell_type": "markdown",
"id": "7a63ce0c-ad3e-4434-b2a8-c8f2a4c2e7a5",
"metadata": {},
"source": [
"#### 4. Create Global Model\n",
"Now, we create the initial global model and send to server."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e2c514c-7758-4d30-bb5c-ae3c63be50aa",
"metadata": {},
"outputs": [],
"source": [
"from src.net import Net\n",
"from nvflare.app_opt.pt.job_config.model import PTModel\n",
"\n",
"job.to(PTModel(Net()), \"server\")"
]
},
{
"cell_type": "markdown",
"id": "eccae908",
"metadata": {},
"source": [
"#### 5. Add ModelSelector\n",
"Add IntimeModelSelector for global best model selection."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d52dd194",
"metadata": {},
"outputs": [],
"source": [
"from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n",
"\n",
"job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")"
]
},
{
"cell_type": "markdown",
"id": "3fbca796-676c-416a-a500-de3429e4a39f",
"metadata": {},
"source": [
"#### 6. Add TB Receiver\n",
"Add TBAnalyticsReceiver for tensorboard records."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3426c2a-d84a-4ed7-8c23-9e1ae389d215",
"metadata": {},
"outputs": [],
"source": [
"from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver\n",
"\n",
"component = TBAnalyticsReceiver(events=[\"fed.analytix_log_stats\"])\n",
"job.to(id=\"receiver\", obj=component, target=\"server\")"
]
},
{
"cell_type": "markdown",
"id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070",
@@ -376,66 +319,12 @@
"That completes the components that need to be defined on the server."
]
},
{
"cell_type": "markdown",
"id": "6059b304",
"metadata": {},
"source": [
"#### 7. Add TB Event\n",
"Add tensorboard logging to clients"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51d8bcda",
"metadata": {},
"outputs": [],
"source": [
"from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent\n",
"\n",
"for i in range(n_clients):\n",
" component = ConvertToFedEvent(events_to_convert=[\"analytix_log_stats\"], fed_event_prefix=\"fed.\")\n",
" job.to(id=\"event_to_fed\", obj=component, target=f\"site-{i+1}\")"
]
},
{
"cell_type": "markdown",
"id": "7c95e3f6",
"metadata": {},
"source": [
"#### OPTIONAL: Define a FedAvgJob\n",
"\n",
"Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n",
"The `FedAvgJob` automatically configures the `FedAvg`` server controller, along the other components for model persistence, model selection, and TensorBoard streaming.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4dfc3e7",
"metadata": {},
"outputs": [],
"source": [
"from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n",
"\n",
"n_clients = 2\n",
"\n",
"# Create FedAvg Job with initial model\n",
"job = FedAvgJob(\n",
" name=\"cifar10_pt_fedavg\",\n",
" num_rounds=2,\n",
" n_clients=n_clients,\n",
" initial_model=Net(),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0",
"metadata": {},
"source": [
"#### 8. Add client ScriptRunners\n",
"#### 4. Add clients\n",
"Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n",
"\n",
"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity."
@@ -462,7 +351,7 @@
"source": [
"That's it!\n",
"\n",
"#### 9. Optionally export the job\n",
"#### 5. Optionally export the job\n",
"Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html)."
]
},
@@ -481,7 +370,7 @@
"id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63",
"metadata": {},
"source": [
"#### 10. Run FL Simulation\n",
"#### 6. Run FL Simulation\n",
"Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`."
]
},
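The export and simulation cells referenced by steps 5 and 6 are untouched by this commit, so they don't appear in the hunks above. For completeness, running the job typically comes down to the two FedJob helpers sketched below; the workspace paths and GPU id are placeholders, and the exact keyword names should be checked against the notebooks:

# Step 5 (optional): write the job configuration out for a real deployment.
job.export_job("/tmp/nvflare/jobs/job_config")  # output path is a placeholder

# Step 6: run the job with NVFlare's simulator.
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")  # gpu id is a placeholder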