MONAI mednist example (NVIDIA#2532)
* add monai notebook

* add training script

* update example

* update notebook

* use job template

* call init later

* switch back

* add gitignore

* update notebooks

* add readmes

* send received model to GPU

* use monai tb stats handler

* formatting
holgerroth authored and MinghuiChen43 committed May 6, 2024
1 parent f251451 commit 46b8d2a
Showing 15 changed files with 907 additions and 0 deletions.
4 changes: 4 additions & 0 deletions integration/monai/examples/README.md
@@ -1,5 +1,9 @@
# Examples of MONAI-NVFlare Integration

### [Converting MONAI Code to a Federated Learning Setting](./mednist/README.md)
A tutorial showing how simple it can be to run an end-to-end classification pipeline with MONAI
and deploy it in a federated learning setting using NVFlare.

### [Simulated Federated Learning for 3D spleen CT segmentation](./spleen_ct_segmentation_sim/README.md)
An example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)
to train a medical image analysis model using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629))
3 changes: 3 additions & 0 deletions integration/monai/examples/mednist/.gitignore
@@ -0,0 +1,3 @@
# nvflare artifacts for this example
fedavg_workspace
jobs
18 changes: 18 additions & 0 deletions integration/monai/examples/mednist/README.md
@@ -0,0 +1,18 @@
## Converting MONAI Code to a Federated Learning Setting

This tutorial shows how simple it can be to run an end-to-end classification pipeline with MONAI
and deploy it in a federated learning setting using NVFlare.

### 1. Standalone training with MONAI
[monai_101.ipynb](./monai_101.ipynb) is based on the [MONAI 101 classification tutorial](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and shows each required step in only a few lines of code, including:

- Downloading the dataset
- Pre-processing the data
- Defining a DenseNet-121 and running training
- Checking the results on the test dataset

### 2. Federated learning with MONAI
[monai_101_fl.ipynb](./monai_101_fl.ipynb) shows how we can put the code introduced above into a Python script and convert it to run in a federated learning (FL) scenario using NVFlare.

To achieve this, we utilize the [`FedAvg`](https://arxiv.org/abs/1602.05629) algorithm and NVFlare's [Client
API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api).
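
To give an intuition for what `FedAvg` does on the server side, here is a minimal sketch of weighted averaging in plain Python. It is not NVFlare's implementation: the function name `fedavg` and the sites `site_1` and `site_2` are hypothetical, weights are represented as plain lists of floats, and we assume each client's update is weighted by the number of local training steps it reports (the training script in this example sends this value as `NUM_STEPS_CURRENT_ROUND`).

```python
from typing import Dict, List

# Hypothetical representation: parameter name -> flat list of values.
Weights = Dict[str, List[float]]


def fedavg(updates: List[Weights], num_steps: List[int]) -> Weights:
    """Average client weights, weighting each client by its local step count."""
    if not updates or len(updates) != len(num_steps):
        raise ValueError("need one step count per client update")
    total = float(sum(num_steps))
    if total <= 0:
        raise ValueError("total number of steps must be positive")

    averaged: Weights = {}
    for name in updates[0]:
        length = len(updates[0][name])
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(updates, num_steps)) / total
            for i in range(length)
        ]
    return averaged


# Two hypothetical sites; site_2 ran three times as many steps as site_1.
site_1 = {"layer.weight": [1.0, 2.0]}
site_2 = {"layer.weight": [3.0, 6.0]}
global_weights = fedavg([site_1, site_2], num_steps=[100, 300])
print(global_weights)  # {'layer.weight': [2.5, 5.0]}
```

Because `site_2` contributes more steps, the averaged weights sit closer to its values. In a real run, the averaged model would be sent back to the clients for the next round.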
169 changes: 169 additions & 0 deletions integration/monai/examples/mednist/code/monai_mednist_train.py
@@ -0,0 +1,169 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MONAI Example adapted from https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb
#
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.engines import SupervisedTrainer
from monai.handlers import StatsHandler, TensorBoardStatsHandler
from monai.inferers import SimpleInferer
from monai.networks import eval_mode
from monai.networks.nets import densenet121
from monai.transforms import Compose, EnsureChannelFirstD, LoadImageD, ScaleIntensityD

# (1) import nvflare client API
import nvflare.client as flare

# (optional) metrics
from nvflare.client.tracking import SummaryWriter

print_config()


def main():
    # (2) initializes NVFlare client API
    flare.init()

    # Setup data directory
    directory = os.environ.get("MONAI_DATA_DIRECTORY")
    root_dir = tempfile.mkdtemp() if directory is None else directory
    print(root_dir)

    # Use MONAI transforms to preprocess data
    transform = Compose(
        [
            LoadImageD(keys="image", image_only=True),
            EnsureChannelFirstD(keys="image"),
            ScaleIntensityD(keys="image"),
        ]
    )

    # Prepare datasets using MONAI Apps
    dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True)

    # Define a network and a supervised trainer

    # If available, we use GPU to speed things up.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    max_epochs = 1  # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each.
    model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)

    train_loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    trainer = SupervisedTrainer(
        device=torch.device(DEVICE),
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
        loss_function=torch.nn.CrossEntropyLoss(),
        inferer=SimpleInferer(),
        train_handlers=StatsHandler(),
    )

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    summary_writer = SummaryWriter()
    train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=summary_writer)
    train_tensorboard_stats_handler.attach(trainer)

    # (optional) calculate total steps
    steps = max_epochs * len(train_loader)
    # Run the training

    while flare.is_running():
        # (3) receives FLModel from NVFlare
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (4) loads model from NVFlare and sends it to GPU
        trainer.network.load_state_dict(input_model.params)
        trainer.network.to(DEVICE)

        trainer.run()

        # (5) wraps evaluation logic into a method to re-use for
        #     evaluation on both trained and received model
        def evaluate(input_weights):
            # Create model for evaluation
            eval_model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)
            eval_model.load_state_dict(input_weights)

            # Check the prediction on the test dataset
            dataset_dir = Path(root_dir, "MedNIST")
            class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir())
            testdata = MedNISTDataset(
                root_dir=root_dir, transform=transform, section="test", download=False, runtime_cache=True
            )
            correct = 0
            total = 0
            max_items_to_print = 10
            _print = 0
            with eval_mode(eval_model):
                for item in DataLoader(testdata, batch_size=512, num_workers=0):  # changed to do batch processing
                    prob = np.array(eval_model(item["image"].to(DEVICE)).detach().to("cpu"))
                    pred = [class_names[p] for p in prob.argmax(axis=1)]
                    gt = item["class_name"]
                    # changed the logic a bit from tutorial to compute accuracy on full test set
                    # but only print for some.
                    for _gt, _pred in zip(gt, pred):
                        if _print < max_items_to_print:
                            print(f"Class prediction is {_pred}. Ground-truth: {_gt}")
                            _print += 1

                        # compute accuracy
                        total += 1
                        correct += float(_pred == _gt)

            print(f"Accuracy of the network on the {total} test images: {100 * correct // total} %")
            return correct / total

        # (6) evaluate on received model for model selection
        accuracy = evaluate(input_model.params)
        summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round)

        # (7) construct trained FL model
        output_model = flare.FLModel(
            params=trainer.network.cpu().state_dict(),
            metrics={"accuracy": accuracy},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )
        # (8) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
Binary file added integration/monai/examples/mednist/figs/tb.png
@@ -0,0 +1,67 @@
{
# version of the configuration
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "monai_mednist_train.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = ""

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = ["train"]

# This particular executor
executor {

path = "nvflare.app_opt.pt.in_process_client_api_executor.PTInProcessClientAPIExecutor"
args {
task_script_path = "{app_script}"
task_script_args = "{app_config}"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "DIFF"

# if train_with_evaluation is true, the executor expects the custom code
# to send back both the trained parameters and the evaluation metric;
# otherwise only the trained parameters are expected
train_with_evaluation = true

# Time interval in seconds to wait before checking whether the local task has submitted its result.
# If the local task takes a long time, you can increase this interval.
# Set explicitly here; the default is 0.5 seconds.
result_pull_interval = 0.5

# Time interval in seconds to wait before checking whether the training code has logged metrics (such as
# TensorBoard, MLflow, or Weights & Biases logs). The results will be streamed to the server side
# and then to the corresponding tracking system.
# If the logs are not needed, you can set this to a larger number.
# Set explicitly here; the default is None, which disables the log streaming feature.
log_pull_interval = 0.1

}
}
}
],

# this defines an array of task data filters. If provided, they control the data sent from the server controller to the client executor
task_data_filters = []

# this defines an array of task result filters. If provided, they control the result sent from the client executor to the server controller
task_result_filters = []

# define this component that will help relay local metrics log to FL server.
components = [
{
"id": "event_to_fed",
"name": "ConvertToFedEvent",
"args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."}
}
]
}
@@ -0,0 +1,94 @@
{
# version of the configuration
format_version = 2

# task data filter: if filters are provided, the filter will filter the data flow out of server to client.
task_data_filters =[]

# task result filter: if filters are provided, the filter will filter the result flow out of client to server.
task_result_filters = []

# Class path of the model. This example uses MONAI's densenet121;
# if you use a different model, please modify it here.
model_class_path = "monai.networks.nets.densenet121"

# densenet arguments
spatial_dims = 2
in_channels = 1
out_channels = 6

# workflows: Array of workflows that control the Federated Learning workflow lifecycle.
# One can specify multiple workflows. NVFlare will run them in the order specified.
workflows = [
{
# 1st workflow
id = "scatter_and_gather"

# path is the class path of the FedAvg controller.
path = "nvflare.app_common.workflows.fedavg.FedAvg"
args {
# arguments of the FedAvg class.
# min number of clients required for the controller to move to the next round
# during the workflow cycle. The controller will wait until min_clients results are returned from clients
# before moving to the next step.
min_clients = 2

# number of global rounds of training.
num_rounds = 5
}
}
]

# List of components used in the server side workflow.
components = [
{
# This is the persistence component used in the above workflow.
# PTFileModelPersistor is a PyTorch persistor which saves/reads the model to/from file.

id = "persistor"
path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"

# the persistor class takes the model class as an argument.
# This implies that the model is initialized on the server side.
# The initialized model will be broadcast to all the clients to start the training.
args.model.path = "{model_class_path}"
args.model.args.spatial_dims = "{spatial_dims}"
args.model.args.in_channels = "{in_channels}"
args.model.args.out_channels = "{out_channels}"
},
{
# This component is not directly used in the workflow.
# It selects the best model based on the incoming global validation metrics.
id = "model_selector"
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
# make sure this "key_metric" matches what the server side receives
args.key_metric = "accuracy"
},
{
id = "receiver"
path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
args.events = ["fed.analytix_log_stats"]
},
{
id = "mlflow_receiver"
path = "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver"
args {
# tracking_uri = "http://0.0.0.0:5000"
tracking_uri = ""
kwargs {
experiment_name = "nvflare-fedavg-mednist-experiment"
run_name = "nvflare-fedavg-mednist-with-mlflow"
experiment_tags {
"mlflow.note.content": "## **NVFlare FedAvg MONAI experiment with MLflow**"
}
run_tags {
"mlflow.note.content" = "## Federated Experiment tracking with MONAI and MLflow \n###"
}
}
artifact_location = "artifacts"
events = ["fed.analytix_log_stats"]
}
}
]

}
@@ -0,0 +1,5 @@
{
description = "FedAvg using MONAI with in_process Client API"
execution_api_type = "client_api"
controller_type = "server"
}
@@ -0,0 +1,11 @@
# Job Template Information Card

## fedavg_mednist
name = "fedavg_mednist"
description = "FedAvg with Scatter and Gather Workflow using pytorch with in_process Client API"
class_name = "ScatterAndGather"
controller_type = "server"
executor_type = "in_process_client_api_executor"
contributor = "NVIDIA"
init_publish_date = "2024-02-08"
last_updated_date = "2024-02-08" # yyyy-mm-dd
@@ -0,0 +1,10 @@
{
name = "fedavg_mednist"
resource_spec = {}
deploy_map {
# change deploy map as needed.
app = ["@ALL"]
}
min_clients = 2
mandatory_clients = []
}