diff --git a/integration/monai/examples/README.md b/integration/monai/examples/README.md index 6a92e363d4..60fa3a0005 100644 --- a/integration/monai/examples/README.md +++ b/integration/monai/examples/README.md @@ -1,5 +1,9 @@ # Examples of MONAI-NVFlare Integration +### [Converting MONAI Code to a Federated Learning Setting](./mednist/README.md) +A tutorial to show how simple it can be to run an end-to-end classification pipeline with MONAI +and deploy it in a federated learning setting using NVFlare. + ### [Simulated Federated Learning for 3D spleen CT segmentation](./spleen_ct_segmentation_sim/README.md) An example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train a medical image analysis model using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) diff --git a/integration/monai/examples/mednist/.gitignore b/integration/monai/examples/mednist/.gitignore new file mode 100644 index 0000000000..c964d5fc95 --- /dev/null +++ b/integration/monai/examples/mednist/.gitignore @@ -0,0 +1,3 @@ +# nvflare artifacts for this example +fedavg_workspace +jobs \ No newline at end of file diff --git a/integration/monai/examples/mednist/README.md b/integration/monai/examples/mednist/README.md new file mode 100644 index 0000000000..25693d35fc --- /dev/null +++ b/integration/monai/examples/mednist/README.md @@ -0,0 +1,18 @@ +## Converting MONAI Code to a Federated Learning Setting + +In this tutorial, we will introduce how simple it can be to run an end-to-end classification pipeline with MONAI +and deploy it in a federated learning setting using NVFlare. + +### 1. Standalone training with MONAI +[monai_101.ipynb](./monai_101.ipynb) is based on the [MONAI 101 classification tutorial](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and shows each step required in only a few lines of code, including + +- Dataset download +- Data pre-processing +- Define a DenseNet-121 and run training +- Check the results on test dataset + +### 2. Federated learning with MONAI +[monai_101_fl.ipynb](./monai_101_fl.ipynb) shows how we can simply put the code introduced above into a Python script and convert it to running in an FL scenario using NVFlare. + +To achieve this, we utilize the [`FedAvg`](https://arxiv.org/abs/1602.05629) algorithm and NVFlare's [Client +API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api). diff --git a/integration/monai/examples/mednist/code/monai_mednist_train.py b/integration/monai/examples/mednist/code/monai_mednist_train.py new file mode 100644 index 0000000000..86ab525642 --- /dev/null +++ b/integration/monai/examples/mednist/code/monai_mednist_train.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# MONAI Example adopted from https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb +# +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +import tempfile +from pathlib import Path + +import numpy as np +import torch +from monai.apps import MedNISTDataset +from monai.config import print_config +from monai.data import DataLoader +from monai.engines import SupervisedTrainer +from monai.handlers import StatsHandler, TensorBoardStatsHandler +from monai.inferers import SimpleInferer +from monai.networks import eval_mode +from monai.networks.nets import densenet121 +from monai.transforms import Compose, EnsureChannelFirstD, LoadImageD, ScaleIntensityD + +# (1) import nvflare client API +import nvflare.client as flare + +# (optional) metrics +from nvflare.client.tracking import SummaryWriter + +print_config() + + +def main(): + # (2) initializes NVFlare client API + flare.init() + + # Setup data directory + directory = os.environ.get("MONAI_DATA_DIRECTORY") + root_dir = tempfile.mkdtemp() if directory is None else directory + print(root_dir) + + # Use MONAI transforms to preprocess data + transform = Compose( + [ + LoadImageD(keys="image", image_only=True), + EnsureChannelFirstD(keys="image"), + ScaleIntensityD(keys="image"), + ] + ) + + # Prepare datasets using MONAI Apps + dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True) + + # Define a network and a supervised trainer + + # If available, we use GPU to speed things up. + DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + max_epochs = 1 # rather than 5 epochs, we run 5 FL rounds with 1 local epoch each. 
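+    # MedNIST contains six image classes (AbdomenCT, BreastMRI, CXR, ChestCT, Hand, HeadCT),
+    # which is why the DenseNet-121 below is created with out_channels=6.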
+ model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE) + + train_loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4) + + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + trainer = SupervisedTrainer( + device=torch.device(DEVICE), + max_epochs=max_epochs, + train_data_loader=train_loader, + network=model, + optimizer=torch.optim.Adam(model.parameters(), lr=1e-5), + loss_function=torch.nn.CrossEntropyLoss(), + inferer=SimpleInferer(), + train_handlers=StatsHandler(), + ) + + # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler + summary_writer = SummaryWriter() + train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=summary_writer) + train_tensorboard_stats_handler.attach(trainer) + + # (optional) calculate total steps + steps = max_epochs * len(train_loader) + # Run the training + + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (4) loads model from NVFlare and sends it to GPU + trainer.network.load_state_dict(input_model.params) + trainer.network.to(DEVICE) + + trainer.run() + + # (5) wraps evaluation logic into a method to re-use for + # evaluation on both trained and received model + def evaluate(input_weights): + # Create model for evaluation + eval_model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE) + eval_model.load_state_dict(input_weights) + + # Check the prediction on the test dataset + dataset_dir = Path(root_dir, "MedNIST") + class_names = sorted(f"{x.name}" for x in dataset_dir.iterdir() if x.is_dir()) + testdata = MedNISTDataset( + root_dir=root_dir, transform=transform, section="test", download=False, runtime_cache=True + ) + correct = 0 + total = 0 + max_items_to_print = 10 + _print = 0 + with eval_mode(eval_model): + for item in DataLoader(testdata, batch_size=512, num_workers=0): # changed to do batch processing + prob = np.array(eval_model(item["image"].to(DEVICE)).detach().to("cpu")) + pred = [class_names[p] for p in prob.argmax(axis=1)] + gt = item["class_name"] + # changed the logic a bit from tutorial to compute accuracy on full test set + # but only print for some. + for _gt, _pred in zip(gt, pred): + if _print < max_items_to_print: + print(f"Class prediction is {_pred}. 
Ground-truth: {_gt}") + _print += 1 + + # compute accuracy + total += 1 + correct += float(_pred == _gt) + + print(f"Accuracy of the network on the {total} test images: {100 * correct // total} %") + return correct / total + + # (6) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round) + + # (7) construct trained FL model + output_model = flare.FLModel( + params=trainer.network.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + # (8) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/integration/monai/examples/mednist/figs/mlflow.png b/integration/monai/examples/mednist/figs/mlflow.png new file mode 100644 index 0000000000..662e3637f6 Binary files /dev/null and b/integration/monai/examples/mednist/figs/mlflow.png differ diff --git a/integration/monai/examples/mednist/figs/tb.png b/integration/monai/examples/mednist/figs/tb.png new file mode 100644 index 0000000000..329838ca2e Binary files /dev/null and b/integration/monai/examples/mednist/figs/tb.png differ diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf new file mode 100644 index 0000000000..47c9d32c80 --- /dev/null +++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_client.conf @@ -0,0 +1,67 @@ +{ + # version of the configuration + format_version = 2 + + # This is the application script which will be invoked. Client can replace this script with user's own training script. + app_script = "monai_mednist_train.py" + + # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. + # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. + app_config = "" + + # Client Computing Executors. + executors = [ + { + # tasks the executors are defined to handle + tasks = ["train"] + + # This particular executor + executor { + + path = "nvflare.app_opt.pt.in_process_client_api_executor.PTInProcessClientAPIExecutor" + args { + task_script_path = "{app_script}" + task_script_args = "{app_config}" + + # if the transfer_type is FULL, then it will be sent directly + # if the transfer_type is DIFF, then we will calculate the + # difference VS received parameters and send the difference + params_transfer_type = "DIFF" + + # if train_with_evaluation is true, the executor will expect + # the custom code need to send back both the trained parameters and the evaluation metric + # otherwise only trained parameters are expected + train_with_evaluation = true + + # time interval in seconds. Time interval to wait before check if the local task has submitted the result + # if the local task takes long time, you can increase this interval to larger number + # uncomment to overwrite the default, default is 0.5 seconds + result_pull_interval = 0.5 + + # time interval in seconds. Time interval to wait before check if the trainig code has log metric (such as + # Tensorboard log, MLFlow log or Weights & Biases logs. 
The result will be streanmed to the server side + # then to the corresponding tracking system + # if the log is not needed, you can set this to a larger number + # uncomment to overwrite the default, default is None, which disable the log streaming feature. + log_pull_interval = 0.1 + + } + } + } + ], + + # this defined an array of task data filters. If provided, it will control the data from server controller to client executor + task_data_filters = [] + + # this defined an array of task result filters. If provided, it will control the result from client executor to server controller + task_result_filters = [] + + # define this component that will help relay local metrics log to FL server. + components = [ + { + "id": "event_to_fed", + "name": "ConvertToFedEvent", + "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."} + } + ] +} diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf new file mode 100644 index 0000000000..e0336318a4 --- /dev/null +++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/config_fed_server.conf @@ -0,0 +1,94 @@ +{ + # version of the configuration + format_version = 2 + + # task data filter: if filters are provided, the filter will filter the data flow out of server to client. + task_data_filters =[] + + # task result filter: if filters are provided, the filter will filter the result flow out of client to server. + task_result_filters = [] + + # This assumes that there will be a "net.py" file with class name "Net". + # If your model code is not in "net.py" and class name is not "Net", please modify here + model_class_path = "monai.networks.nets.densenet121" + + # densenet arguments + spatial_dims = 2 + in_channels = 1 + out_channels = 6 + + # workflows: Array of workflows the control the Federated Learning workflow lifecycle. + # One can specify multiple workflows. The NVFLARE will run them in the order specified. + workflows = [ + { + # 1st workflow" + id = "scatter_and_gather" + + # name = ScatterAndGather, path is the class path of the ScatterAndGather controller. + path = "nvflare.app_common.workflows.fedavg.FedAvg" + args { + # argument of the ScatterAndGather class. + # min number of clients required for ScatterAndGather controller to move to the next round + # during the workflow cycle. The controller will wait until the min_clients returned from clients + # before move to the next step. + min_clients = 2 + + # number of global round of the training. + num_rounds = 5 + } + } + ] + + # List of components used in the server side workflow. + components = [ + { + # This is the persistence component used in above workflow. + # PTFileModelPersistor is a Pytorch persistor which save/read the model to/from file. + + id = "persistor" + path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" + + # the persitor class take model class as argument + # This imply that the model is initialized from the server-side. + # The initialized model will be broadcast to all the clients to start the training. + args.model.path = "{model_class_path}" + args.model.args.spatial_dims = "{spatial_dims}" + args.model.args.in_channels = "{in_channels}" + args.model.args.out_channels = "{out_channels}" + }, + { + # This component is not directly used in Workflow. + # it select the best model based on the incoming global validation metrics. 
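+      # The key_metric below must match a key in the metrics dict that clients return in their FLModel;
+      # the training script in this example sends metrics={"accuracy": ...}.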
+ id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + # need to make sure this "key_metric" match what server side received + args.key_metric = "accuracy" + }, + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args.events = ["fed.analytix_log_stats"] + }, + { + id = "mlflow_receiver" + path = "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver" + args { + # tracking_uri = "http://0.0.0.0:5000" + tracking_uri = "" + kwargs { + experiment_name = "nvflare-fedavg-mednist-experiment" + run_name = "nvflare-fedavg-mednist-with-mlflow" + experiment_tags { + "mlflow.note.content": "## **NVFlare FedAvg MONAI experiment with MLflow**" + } + run_tags { + "mlflow.note.content" = "## Federated Experiment tracking with MONAI and MLflow \n###" + } + } + artifact_location = "artifacts" + events = ["fed.analytix_log_stats"] + } + } + ] + +} diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/info.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/info.conf new file mode 100644 index 0000000000..123dcb6539 --- /dev/null +++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/info.conf @@ -0,0 +1,5 @@ +{ + description = "FedAvg using MONAI with in_process Client API" + execution_api_type = "client_api" + controller_type = "server" +} \ No newline at end of file diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/info.md b/integration/monai/examples/mednist/job_templates/fedavg_mednist/info.md new file mode 100644 index 0000000000..a35c68ec2d --- /dev/null +++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/info.md @@ -0,0 +1,11 @@ +# Job Template Information Card + +## fedavg_mednist + name = "fedavg_mednist" + description = "FedAvg with Scatter and Gather Workflow using pytorch with in_process Client API" + class_name = "ScatterAndGather" + controller_type = "server" + executor_type = "in_process_client_api_executor" + contributor = "NVIDIA" + init_publish_date = "2024-02-8" + last_updated_date = "2024-02-8" # yyyy-mm-dd diff --git a/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf b/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf new file mode 100644 index 0000000000..43a7e6a2a3 --- /dev/null +++ b/integration/monai/examples/mednist/job_templates/fedavg_mednist/meta.conf @@ -0,0 +1,10 @@ +{ + name = "fedavg_mednist" + resource_spec = {} + deploy_map { + # change deploy map as needed. + app = ["@ALL"] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/integration/monai/examples/mednist/monai_101.ipynb b/integration/monai/examples/mednist/monai_101.ipynb new file mode 100644 index 0000000000..0a8b42ba69 --- /dev/null +++ b/integration/monai/examples/mednist/monai_101.ipynb @@ -0,0 +1,286 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) 2024, NVIDIA CORPORATION. 
All rights reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "\n", + "MONAI Example adopted from https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb\n", + "\n", + "Copyright (c) MONAI Consortium \n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", + "you may not use this file except in compliance with the License. \n", + "You may obtain a copy of the License at \n", + "    http://www.apache.org/licenses/LICENSE-2.0 \n", + "Unless required by applicable law or agreed to in writing, software \n", + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", + "See the License for the specific language governing permissions and \n", + "limitations under the License.\n", + "\n", + "# MONAI 101 tutorial\n", + "\n", + "In this tutorial, we will introduce how simple it can be to run an end-to-end classification pipeline with MONAI.\n", + "\n", + "These steps will be included in this tutorial, and each of them will take only a few lines of code:\n", + "- Dataset download\n", + "- Data pre-processing\n", + "- Define a DenseNet-121 and run training\n", + "- Check the results on test dataset\n", + "\n", + "This tutorial will use about 7GB of GPU memory and 10 minutes to run.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/NVFlare/blob/main/integration/monai/examples/mednist/monai_101.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite, tqdm]\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import numpy as np\n", + "import os\n", + "from pathlib import Path\n", + "import sys\n", + "import tempfile\n", + "import torch\n", + "\n", + "from monai.apps import MedNISTDataset\n", + "from monai.config import print_config\n", + "from monai.data import DataLoader\n", + "from monai.engines import SupervisedTrainer\n", + "from monai.handlers import StatsHandler\n", + "from monai.inferers import SimpleInferer\n", + "from monai.networks import eval_mode\n", + "from monai.networks.nets import densenet121\n", + "from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose\n", + "\n", + "print_config()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup data directory\n", + "\n", + "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. 
\n", + "This allows you to save results and reuse downloads. \n", + "If not specified a temporary directory will be used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use MONAI transforms to preprocess data\n", + "\n", + "Medical images require specialized methods for I/O, preprocessing, and augmentation.\n", + "They often follow specific formats, are handled with specific protocols, and the data arrays are often high-dimensional.\n", + "\n", + "In this example, we will perform image loading, data format verification, and intensity scaling with three `monai.transforms` listed below, and compose a pipeline ready to be used in next steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transform = Compose(\n", + " [\n", + " LoadImageD(keys=\"image\", image_only=True),\n", + " EnsureChannelFirstD(keys=\"image\"),\n", + " ScaleIntensityD(keys=\"image\"),\n", + " ]\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare datasets using MONAI Apps\n", + "\n", + "We use `MedNISTDataset` in MONAI Apps to download a dataset to the specified directory and perform the pre-processing steps in the `monai.transforms` compose.\n", + "\n", + "The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),\n", + "[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),\n", + "and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).\n", + "\n", + "The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)\n", + "under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).\n", + "\n", + "If you use the MedNIST dataset, please acknowledge the source. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"training\", download=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define a network and a supervised trainer\n", + "\n", + "To train a model that can perform the classification task, we will use the DenseNet-121 which is known for its performance on the ImageNet dataset.\n", + "\n", + "For a typical supervised training workflow, MONAI provides `SupervisedTrainer` to define the hyper-parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If available, we use GPU to speed things up.\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "max_epochs = 5\n", + "model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(DEVICE)\n", + "\n", + "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", + "trainer = SupervisedTrainer(\n", + " device=torch.device(DEVICE),\n", + " max_epochs=max_epochs,\n", + " train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),\n", + " network=model,\n", + " optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),\n", + " loss_function=torch.nn.CrossEntropyLoss(),\n", + " inferer=SimpleInferer(),\n", + " train_handlers=StatsHandler(),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.run()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check the prediction on the test dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_dir = Path(root_dir, \"MedNIST\")\n", + "class_names = sorted(f\"{x.name}\" for x in dataset_dir.iterdir() if x.is_dir())\n", + "testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"test\", download=False, runtime_cache=True)\n", + "\n", + "max_items_to_print = 10\n", + "with eval_mode(model):\n", + " for item in DataLoader(testdata, batch_size=1, num_workers=0):\n", + " prob = np.array(model(item[\"image\"].to(DEVICE)).detach().to(\"cpu\"))[0]\n", + " pred = class_names[prob.argmax()]\n", + " gt = item[\"class_name\"][0]\n", + " print(f\"Class prediction is {pred}. Ground-truth: {gt}\")\n", + " max_items_to_print -= 1\n", + " if max_items_to_print == 0:\n", + " break" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/integration/monai/examples/mednist/monai_101_fl.ipynb b/integration/monai/examples/mednist/monai_101_fl.ipynb new file mode 100644 index 0000000000..e63db20b01 --- /dev/null +++ b/integration/monai/examples/mednist/monai_101_fl.ipynb @@ -0,0 +1,233 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) 2024, NVIDIA CORPORATION. 
All rights reserved.\n",
+    "\n",
+    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "you may not use this file except in compliance with the License.\n",
+    "You may obtain a copy of the License at\n",
+    "\n",
+    "    http://www.apache.org/licenses/LICENSE-2.0\n",
+    "\n",
+    "Unless required by applicable law or agreed to in writing, software\n",
+    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "See the License for the specific language governing permissions and\n",
+    "limitations under the License.\n",
+    "\n",
+    "MONAI Example adapted from https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb\n",
+    "\n",
+    "Copyright (c) MONAI Consortium \n",
+    "Licensed under the Apache License, Version 2.0 (the \"License\"); \n",
+    "you may not use this file except in compliance with the License. \n",
+    "You may obtain a copy of the License at \n",
+    "    http://www.apache.org/licenses/LICENSE-2.0 \n",
+    "Unless required by applicable law or agreed to in writing, software \n",
+    "distributed under the License is distributed on an \"AS IS\" BASIS, \n",
+    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n",
+    "See the License for the specific language governing permissions and \n",
+    "limitations under the License.\n",
+    "\n",
+    "# MONAI 101 tutorial with Federated Learning\n",
+    "\n",
+    "In this example, the **server** uses the [`FedAvg`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/fedavg.py) controller, which performs the following steps:\n",
+    "1. Initialize the global model. This is achieved through the method `load_model()`\n",
+    "   from the base class\n",
+    "   [`ModelController`](https://github.com/NVIDIA/NVFlare/blob/fa4d00f76848fe4eb356dcde417c136047eeab36/nvflare/app_common/workflows/model_controller.py#L292),\n",
+    "   which relies on the\n",
+    "   [`ModelPersistor`](https://nvflare.readthedocs.io/en/main/glossary.html#persistor).\n",
+    "2. During each training round, the global model will be sent to the\n",
+    "   list of participating clients to perform a training task. This is\n",
+    "   done under the hood using the\n",
+    "   [`send_model()`](https://github.com/NVIDIA/NVFlare/blob/d6827bca96d332adb3402ceceb4b67e876146067/nvflare/app_common/workflows/model_controller.py#L99)\n",
+    "   method from the `ModelController` base class. Once\n",
+    "   the clients finish their local training, results will be collected\n",
+    "   and sent back to the server as [`FLModel`](https://nvflare.readthedocs.io/en/main/programming_guide/fl_model.html#flmodel)s.\n",
+    "3. Results sent by clients will be aggregated based on the\n",
+    "   [`WeightedAggregationHelper`](https://github.com/NVIDIA/NVFlare/blob/fa4d00f76848fe4eb356dcde417c136047eeab36/nvflare/app_common/aggregators/weighted_aggregation_helper.py#L20),\n",
+    "   which weighs the contribution from each client based on the number\n",
+    "   of local training samples. The aggregated updates are\n",
+    "   returned as a new `FLModel`.\n",
+    "4. After getting the aggregated results, the global model is [updated](https://github.com/NVIDIA/NVFlare/blob/724140e7dc9081eca7a912a818817f89aadfef5d/nvflare/app_common/workflows/fedavg.py#L63).\n",
+    "5. The last step is to save the updated global model, again through\n",
+    "   the [`ModelPersistor`](https://nvflare.readthedocs.io/en/main/glossary.html#persistor).\n",
+    "\n",
+    "The **clients** implement the local training logic using NVFlare's [Client\n",
+    "API](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api)\n",
+    "[here](./code/monai_mednist_train.py). The Client API\n",
+    "allows the user to add a minimal amount of `nvflare`-specific code to turn a typical\n",
+    "centralized training script into a federated client-side local training\n",
+    "script.\n",
+    "1. During local training, each client receives a copy of the global\n",
+    "   model sent by the server using the `flare.receive()` API. The received\n",
+    "   global model is an instance of `FLModel`.\n",
+    "2. A local validation is first performed, where validation metrics\n",
+    "   (here, accuracy) are streamed to the server using the\n",
+    "   [`SummaryWriter`](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.client.tracking.html#nvflare.client.tracking.SummaryWriter). The\n",
+    "   streamed metrics can be loaded and visualized using [TensorBoard](https://www.tensorflow.org/tensorboard) or [MLflow](https://mlflow.org/).\n",
+    "3. Then, each client performs local training as in the non-federated training [notebook](./monai_101.ipynb). At the end of each FL round, each client sends the computed results (always in\n",
+    "   `FLModel` format) to the server for aggregation, using the `flare.send()`\n",
+    "   API.\n",
+    "\n",
+    "This tutorial requires about 7GB of GPU memory and takes about 10 minutes to run.\n",
+    "\n",
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/NVFlare/blob/main/integration/monai/examples/mednist/monai_101_fl.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite, tqdm]\"\n",
+    "!pip install -r requirements.txt"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configure the NVFlare job templates folder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!nvflare config -jt ./job_templates\n",
+    "!nvflare job list_templates"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create job folder\n",
+    "\n",
+    "We will use the in-process Client API. We prepared a job template ([fedavg_mednist](./job_templates/fedavg_mednist)) based on the [sag_pt in_proc job template](../../../job_templates/sag_pt_in_proc) and run the following command to create the job.\n",
+    "The `-f` option allows us to customize some options in the template, such as specifying the training script to be used on the clients and initial arguments to the global model, as well as the number of FL rounds."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!nvflare job create -force -j ./jobs/fedavg_mednist -w fedavg_mednist -sd ./code/. 
\\\n", + " -f config_fed_client.conf app_script=monai_mednist_train.py \\\n", + " -f config_fed_server.conf model_class_path=monai.networks.nets.densenet121 \\\n", + " -f config_fed_server.conf spatial_dims=2 \\\n", + " -f config_fed_server.conf in_channels=1 \\\n", + " -f config_fed_server.conf out_channels=6 \\\n", + " -f config_fed_server.conf num_rounds=5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run FL experiment\n", + "Then we can run it using the NVFlare Simulator for `n=2` clients on `t=2` threads in parallel:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nvflare simulator -n 2 -t 2 ./jobs/fedavg_mednist -w fedavg_workspace" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize the streamed metrics\n", + "The accuracy metrics streamed to the server during training can be visualized using either\n", + "\n", + "1. TensorBoard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!tensorboard --logdir fedavg_workspace/server/simulate_job/app_server/simulate_job/tb_events" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"TensorBoard" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "or\n", + "\n", + "2. MLflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mlflow ui --backend-store-uri fedavg_workspace/server/simulate_job/app_server/mlruns" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"MLflow" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/integration/monai/examples/mednist/requirements.txt b/integration/monai/examples/mednist/requirements.txt new file mode 100644 index 0000000000..00e550082a --- /dev/null +++ b/integration/monai/examples/mednist/requirements.txt @@ -0,0 +1,3 @@ +nvflare-nightly @ git+https://github.com/NVIDIA/NVFlare.git@main +tensorboard +mlflow \ No newline at end of file diff --git a/nvflare/client/tracking.py b/nvflare/client/tracking.py index c04b797037..e2ba7e66b7 100644 --- a/nvflare/client/tracking.py +++ b/nvflare/client/tracking.py @@ -64,6 +64,10 @@ def add_scalars(self, tag: str, scalars: dict, global_step: Optional[int] = None **kwargs, ) + def flush(self): + """Skip flushing which would normally write the event file to disk""" + pass + class WandBWriter: """WandBWriter mimics the usage of weights and biases.