forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add monai notebook * add training script * update example * update notebook * use job template * call init later * swith back * add gitignore * update notebooks * add readmes * send received model to GPU * use monai tb stats handler * formatting
- Loading branch information
1 parent
f251451
commit 46b8d2a
Showing
15 changed files
with
907 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# nvflare artifacts for this example | ||
fedavg_workspace | ||
jobs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
## Converting MONAI Code to a Federated Learning Setting | ||
|
||
In this tutorial, we will introduce how simple it can be to run an end-to-end classification pipeline with MONAI | ||
and deploy it in a federated learning setting using NVFlare. | ||
|
||
### 1. Standalone training with MONAI | ||
[monai_101.ipynb](./monai_101.ipynb) is based on the [MONAI 101 classification tutorial](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and shows each step required in only a few lines of code, including | ||
|
||
- Dataset download | ||
- Data pre-processing | ||
- Define a DenseNet-121 and run training | ||
- Check the results on test dataset | ||
|
||
### 2. Federated learning with MONAI | ||
[monai_101_fl.ipynb](./monai_101_fl.ipynb) shows how we can simply put the code introduced above into a Python script and convert it to running in an FL scenario using NVFlare. | ||
|
||
To achieve this, we utilize the [`FedAvg`](https://arxiv.org/abs/1602.05629) algorithm and NVFlare's [Client | ||
API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api). |
169 changes: 169 additions & 0 deletions
169
integration/monai/examples/mednist/code/monai_mednist_train.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# MONAI Example adopted from https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb | ||
# | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import os | ||
import sys | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import torch | ||
from monai.apps import MedNISTDataset | ||
from monai.config import print_config | ||
from monai.data import DataLoader | ||
from monai.engines import SupervisedTrainer | ||
from monai.handlers import StatsHandler, TensorBoardStatsHandler | ||
from monai.inferers import SimpleInferer | ||
from monai.networks import eval_mode | ||
from monai.networks.nets import densenet121 | ||
from monai.transforms import Compose, EnsureChannelFirstD, LoadImageD, ScaleIntensityD | ||
|
||
# (1) import nvflare client API | ||
import nvflare.client as flare | ||
|
||
# (optional) metrics | ||
from nvflare.client.tracking import SummaryWriter | ||
|
||
print_config() | ||
|
||
|
||
def main(): | ||
# (2) initializes NVFlare client API | ||
flare.init() | ||
|
||
# Setup data directory | ||
directory = os.environ.get("MONAI_DATA_DIRECTORY") | ||
root_dir = tempfile.mkdtemp() if directory is None else directory | ||
print(root_dir) | ||
|
||
# Use MONAI transforms to preprocess data | ||
transform = Compose( | ||
[ | ||
LoadImageD(keys="image", image_only=True), | ||
EnsureChannelFirstD(keys="image"), | ||
ScaleIntensityD(keys="image"), | ||
] | ||
) | ||
|
||
# Prepare datasets using MONAI Apps | ||
dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True) | ||
|
||
# Define a network and a supervised trainer | ||
|
||
# If available, we use GPU to speed things up. | ||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
max_epochs = 1 # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each. | ||
model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE) | ||
|
||
train_loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4) | ||
|
||
logging.basicConfig(stream=sys.stdout, level=logging.INFO) | ||
trainer = SupervisedTrainer( | ||
device=torch.device(DEVICE), | ||
max_epochs=max_epochs, | ||
train_data_loader=train_loader, | ||
network=model, | ||
optimizer=torch.optim.Adam(model.parameters(), lr=1e-5), | ||
loss_function=torch.nn.CrossEntropyLoss(), | ||
inferer=SimpleInferer(), | ||
train_handlers=StatsHandler(), | ||
) | ||
|
||
# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler | ||
summary_writer = SummaryWriter() | ||
train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=summary_writer) | ||
train_tensorboard_stats_handler.attach(trainer) | ||
|
||
# (optional) calculate total steps | ||
steps = max_epochs * len(train_loader) | ||
# Run the training | ||
|
||
while flare.is_running(): | ||
# (3) receives FLModel from NVFlare | ||
input_model = flare.receive() | ||
print(f"current_round={input_model.current_round}") | ||
|
||
# (4) loads model from NVFlare and sends it to GPU | ||
trainer.network.load_state_dict(input_model.params) | ||
trainer.network.to(DEVICE) | ||
|
||
trainer.run() | ||
|
||
# (5) wraps evaluation logic into a method to re-use for | ||
# evaluation on both trained and received model | ||
def evaluate(input_weights): | ||
# Create model for evaluation | ||
eval_model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE) | ||
eval_model.load_state_dict(input_weights) | ||
|
||
# Check the prediction on the test dataset | ||
dataset_dir = Path(root_dir, "MedNIST") | ||
class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir()) | ||
testdata = MedNISTDataset( | ||
root_dir=root_dir, transform=transform, section="test", download=False, runtime_cache=True | ||
) | ||
correct = 0 | ||
total = 0 | ||
max_items_to_print = 10 | ||
_print = 0 | ||
with eval_mode(eval_model): | ||
for item in DataLoader(testdata, batch_size=512, num_workers=0): # changed to do batch processing | ||
prob = np.array(eval_model(item["image"].to(DEVICE)).detach().to("cpu")) | ||
pred = [class_names[p] for p in prob.argmax(axis=1)] | ||
gt = item["class_name"] | ||
# changed the logic a bit from tutorial to compute accuracy on full test set | ||
# but only print for some. | ||
for _gt, _pred in zip(gt, pred): | ||
if _print < max_items_to_print: | ||
print(f"Class prediction is {_pred}. Ground-truth: {_gt}") | ||
_print += 1 | ||
|
||
# compute accuracy | ||
total += 1 | ||
correct += float(_pred == _gt) | ||
|
||
print(f"Accuracy of the network on the {total} test images: {100 * correct // total} %") | ||
return correct / total | ||
|
||
# (6) evaluate on received model for model selection | ||
accuracy = evaluate(input_model.params) | ||
summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round) | ||
|
||
# (7) construct trained FL model | ||
output_model = flare.FLModel( | ||
params=trainer.network.cpu().state_dict(), | ||
metrics={"accuracy": accuracy}, | ||
meta={"NUM_STEPS_CURRENT_ROUND": steps}, | ||
) | ||
# (8) send model back to NVFlare | ||
flare.send(output_model) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
67 changes: 67 additions & 0 deletions
67
integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
|
||
# This is the application script which will be invoked. Client can replace this script with user's own training script. | ||
app_script = "monai_mednist_train.py" | ||
|
||
# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. | ||
# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. | ||
app_config = "" | ||
|
||
# Client Computing Executors. | ||
executors = [ | ||
{ | ||
# tasks the executors are defined to handle | ||
tasks = ["train"] | ||
|
||
# This particular executor | ||
executor { | ||
|
||
path = "nvflare.app_opt.pt.in_process_client_api_executor.PTInProcessClientAPIExecutor" | ||
args { | ||
task_script_path = "{app_script}" | ||
task_script_args = "{app_config}" | ||
|
||
# if the transfer_type is FULL, then it will be sent directly | ||
# if the transfer_type is DIFF, then we will calculate the | ||
# difference VS received parameters and send the difference | ||
params_transfer_type = "DIFF" | ||
|
||
# if train_with_evaluation is true, the executor will expect | ||
# the custom code need to send back both the trained parameters and the evaluation metric | ||
# otherwise only trained parameters are expected | ||
train_with_evaluation = true | ||
|
||
# time interval in seconds. Time interval to wait before check if the local task has submitted the result | ||
# if the local task takes long time, you can increase this interval to larger number | ||
# uncomment to overwrite the default, default is 0.5 seconds | ||
result_pull_interval = 0.5 | ||
|
||
# time interval in seconds. Time interval to wait before check if the trainig code has log metric (such as | ||
# Tensorboard log, MLFlow log or Weights & Biases logs. The result will be streanmed to the server side | ||
# then to the corresponding tracking system | ||
# if the log is not needed, you can set this to a larger number | ||
# uncomment to overwrite the default, default is None, which disable the log streaming feature. | ||
log_pull_interval = 0.1 | ||
|
||
} | ||
} | ||
} | ||
], | ||
|
||
# this defined an array of task data filters. If provided, it will control the data from server controller to client executor | ||
task_data_filters = [] | ||
|
||
# this defined an array of task result filters. If provided, it will control the result from client executor to server controller | ||
task_result_filters = [] | ||
|
||
# define this component that will help relay local metrics log to FL server. | ||
components = [ | ||
{ | ||
"id": "event_to_fed", | ||
"name": "ConvertToFedEvent", | ||
"args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."} | ||
} | ||
] | ||
} |
94 changes: 94 additions & 0 deletions
94
integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
|
||
# task data filter: if filters are provided, the filter will filter the data flow out of server to client. | ||
task_data_filters =[] | ||
|
||
# task result filter: if filters are provided, the filter will filter the result flow out of client to server. | ||
task_result_filters = [] | ||
|
||
# This assumes that there will be a "net.py" file with class name "Net". | ||
# If your model code is not in "net.py" and class name is not "Net", please modify here | ||
model_class_path = "monai.networks.nets.densenet121" | ||
|
||
# densenet arguments | ||
spatial_dims = 2 | ||
in_channels = 1 | ||
out_channels = 6 | ||
|
||
# workflows: Array of workflows the control the Federated Learning workflow lifecycle. | ||
# One can specify multiple workflows. The NVFLARE will run them in the order specified. | ||
workflows = [ | ||
{ | ||
# 1st workflow" | ||
id = "scatter_and_gather" | ||
|
||
# name = ScatterAndGather, path is the class path of the ScatterAndGather controller. | ||
path = "nvflare.app_common.workflows.fedavg.FedAvg" | ||
args { | ||
# argument of the ScatterAndGather class. | ||
# min number of clients required for ScatterAndGather controller to move to the next round | ||
# during the workflow cycle. The controller will wait until the min_clients returned from clients | ||
# before move to the next step. | ||
min_clients = 2 | ||
|
||
# number of global round of the training. | ||
num_rounds = 5 | ||
} | ||
} | ||
] | ||
|
||
# List of components used in the server side workflow. | ||
components = [ | ||
{ | ||
# This is the persistence component used in above workflow. | ||
# PTFileModelPersistor is a Pytorch persistor which save/read the model to/from file. | ||
|
||
id = "persistor" | ||
path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" | ||
|
||
# the persitor class take model class as argument | ||
# This imply that the model is initialized from the server-side. | ||
# The initialized model will be broadcast to all the clients to start the training. | ||
args.model.path = "{model_class_path}" | ||
args.model.args.spatial_dims = "{spatial_dims}" | ||
args.model.args.in_channels = "{in_channels}" | ||
args.model.args.out_channels = "{out_channels}" | ||
}, | ||
{ | ||
# This component is not directly used in Workflow. | ||
# it select the best model based on the incoming global validation metrics. | ||
id = "model_selector" | ||
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" | ||
# need to make sure this "key_metric" match what server side received | ||
args.key_metric = "accuracy" | ||
}, | ||
{ | ||
id = "receiver" | ||
path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" | ||
args.events = ["fed.analytix_log_stats"] | ||
}, | ||
{ | ||
id = "mlflow_receiver" | ||
path = "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver" | ||
args { | ||
# tracking_uri = "http://0.0.0.0:5000" | ||
tracking_uri = "" | ||
kwargs { | ||
experiment_name = "nvflare-fedavg-mednist-experiment" | ||
run_name = "nvflare-fedavg-mednist-with-mlflow" | ||
experiment_tags { | ||
"mlflow.note.content": "## **NVFlare FedAvg MONAI experiment with MLflow**" | ||
} | ||
run_tags { | ||
"mlflow.note.content" = "## Federated Experiment tracking with MONAI and MLflow \n###" | ||
} | ||
} | ||
artifact_location = "artifacts" | ||
events = ["fed.analytix_log_stats"] | ||
} | ||
} | ||
] | ||
|
||
} |
5 changes: 5 additions & 0 deletions
5
integration/monai/examples/mednist/job_templates/fedavg_mednist/info.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
description = "FedAvg using MONAI with in_process Client API" | ||
execution_api_type = "client_api" | ||
controller_type = "server" | ||
} |
11 changes: 11 additions & 0 deletions
11
integration/monai/examples/mednist/job_templates/fedavg_mednist/info.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Job Template Information Card | ||
|
||
## fedavg_mednist | ||
name = "fedavg_mednist" | ||
description = "FedAvg with Scatter and Gather Workflow using pytorch with in_process Client API" | ||
class_name = "ScatterAndGather" | ||
controller_type = "server" | ||
executor_type = "in_process_client_api_executor" | ||
contributor = "NVIDIA" | ||
init_publish_date = "2024-02-8" | ||
last_updated_date = "2024-02-8" # yyyy-mm-dd |
10 changes: 10 additions & 0 deletions
10
integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
name = "fedavg_mednist" | ||
resource_spec = {} | ||
deploy_map { | ||
# change deploy map as needed. | ||
app = ["@ALL"] | ||
} | ||
min_clients = 2 | ||
mandatory_clients = [] | ||
} |
Oops, something went wrong.