forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add client controller executor (NVIDIA#2530)
* add client controller executor * address comments * enhance abort, set peer props * remove asserts --------- Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
- Loading branch information
1 parent
42afc68
commit 0d71db7
Showing
15 changed files
with
765 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
|
||
# This is the application script which will be invoked. Client can replace this script with user's own training script. | ||
app_script = "train.py" | ||
|
||
# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. | ||
app_config = "" | ||
|
||
# Path to defined PyTorch network | ||
# This assumes that there will be a "net.py" file with class name "Net", please modify accordingly | ||
model_class_path = "net.Net" | ||
|
||
# Client Computing Executors. | ||
executors = [ | ||
{ | ||
# tasks the executors are defined to handle | ||
tasks = [ | ||
"train", | ||
"validate", | ||
"submit_model" | ||
] | ||
|
||
# This particular executor | ||
executor { | ||
|
||
# This is an executor for pytorch + Client API. The underline data exchange is using Pipe. | ||
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" | ||
|
||
args { | ||
|
||
# launcher_id is used to locate the Launcher object in "components" | ||
launcher_id = "launcher" | ||
|
||
# pipe_id is used to locate the Pipe object in "components" | ||
pipe_id = "pipe" | ||
|
||
# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. | ||
# Please refer to the class docstring for all available arguments | ||
heartbeat_timeout = 60 | ||
|
||
# format of the exchange parameters | ||
params_exchange_format = "pytorch" | ||
|
||
# if the transfer_type is FULL, then it will be sent directly | ||
# if the transfer_type is DIFF, then we will calculate the | ||
# difference VS received parameters and send the difference | ||
params_transfer_type = "FULL" | ||
# if train_with_evaluation is true, the executor will expect | ||
# the custom code need to send back both the trained parameters and the evaluation metric | ||
# otherwise only trained parameters are expected | ||
train_with_evaluation = true | ||
|
||
train_task_name = "train" | ||
evaluate_task_name = "validate" | ||
submit_model_task_name = "submit_model" | ||
} | ||
} | ||
} | ||
{ | ||
# All tasks prefixed with wf_ are routed to this ClientControllerExecutor | ||
tasks = ["wf_*"] | ||
executor { | ||
id = "client_controller_executor" | ||
path = "nvflare.app_common.ccwf.client_controller_executor.ClientControllerExecutor" | ||
# ClientControllerExecutor for running controllers on client-side. | ||
args { | ||
# list of controller ids from components to be run in order | ||
controller_id_list = ["sag_ctl", "cse_ctl"] | ||
task_name_prefix = "wf" | ||
# persistor used to distribute and save final results for clients | ||
persistor_id = "persistor" | ||
} | ||
} | ||
} | ||
] | ||
|
||
# Array of task data filters. If provided, it will control the data from client controller to client executor | ||
# Filter direction (in, out, inout) can be set as since clients send tasks to each other, a task has both a sending (out) and a receiving (in) direction | ||
task_data_filters = [] | ||
|
||
# Array of task result filters. If provided, it will control the data from client executor to client controller | ||
# Filter direction (in, out, inout) can be set as since clients send tasks to each other, a task has both a sending (out) and a receiving (in) direction | ||
task_result_filters = [] | ||
|
||
components = [ | ||
{ | ||
id = "sag_ctl" | ||
path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather" | ||
args { | ||
min_clients = 2 | ||
num_rounds = 3 | ||
start_round = 0 | ||
wait_time_after_min_received = 0 | ||
aggregator_id = "aggregator" | ||
persistor_id = "persistor" | ||
shareable_generator_id = "shareable_generator" | ||
train_task_name = "train" | ||
train_timeout = 0 | ||
} | ||
} | ||
{ | ||
id = "cse_ctl", | ||
path = "nvflare.app_common.workflows.cross_site_model_eval.CrossSiteModelEval", | ||
args { | ||
model_locator_id = "model_locator", | ||
submit_model_timeout = 600, | ||
validation_timeout = 6000, | ||
cleanup_models = false | ||
} | ||
} | ||
{ | ||
# component id is "launcher" | ||
id = "launcher" | ||
|
||
# the class path of this component | ||
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" | ||
|
||
args { | ||
# the launcher will invoke the script | ||
script = "python3 custom/{app_script} {app_config} " | ||
# if launch_once is true, the SubprocessLauncher will launch once for the whole job | ||
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server | ||
launch_once = true | ||
} | ||
} | ||
{ | ||
id = "pipe" | ||
|
||
path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe" | ||
|
||
args { | ||
# Mode of the endpoint. A pipe has two endpoints. | ||
# An endpoint can be either the one that initiates communication or the one listening. | ||
# PASSIVE is the one listening. | ||
mode = "PASSIVE" | ||
|
||
# root_path: is the directory location of the parameters exchange. | ||
# You can also set it to an absolute path in your system. | ||
root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}" | ||
} | ||
} | ||
# required components for the client-controlled workflow defined on client-side | ||
{ | ||
id = "persistor" | ||
path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" | ||
args.model.path = "{model_class_path}" | ||
} | ||
{ | ||
id = "shareable_generator" | ||
path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator" | ||
args = {} | ||
} | ||
{ | ||
# This is the aggregator that perform the weighted average aggregation. | ||
# the aggregation is "in-time", so it doesn't wait for client results, but aggregates as soon as it received the data. | ||
id = "aggregator" | ||
path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator" | ||
args.expected_data_kind = "WEIGHTS" | ||
}, | ||
{ | ||
id = "model_locator" | ||
name = "PTFileModelLocator" | ||
args { | ||
pt_persistor_id = "persistor" | ||
} | ||
}, | ||
{ | ||
# This component is not directly used in Workflow. | ||
# it select the best model based on the incoming global validation metrics. | ||
id = "model_selector" | ||
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" | ||
# need to make sure this "key_metric" match what server side received | ||
args.key_metric = "accuracy" | ||
}, | ||
{ | ||
id = "json_generator" | ||
name = "ValidationJsonGenerator" | ||
args {} | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
|
||
# task data filter: if filters are provided, the filter will filter the data flow out of server to client. | ||
task_data_filters =[] | ||
|
||
# task result filter: if filters are provided, the filter will filter the result flow out of client to server. | ||
task_result_filters = [] | ||
|
||
# This assumes that there will be a "net.py" file with class name "Net". | ||
# If your model code is not in "net.py" and class name is not "Net", please modify here | ||
model_class_path = "net.Net" | ||
|
||
# workflows: Array of workflows the control the Federated Learning workflow lifecycle. | ||
# One can specify multiple workflows. The NVFLARE will run them in the order specified. | ||
workflows = [ | ||
{ | ||
# server-side controller to manage job life cycle and configuration | ||
id = "svr_ctl" | ||
path = "nvflare.app_common.ccwf.server_ctl.ServerSideController" | ||
args { | ||
# the prefix for task names of this workflow | ||
task_name_prefix = "wf" | ||
# the maximum amount of time allowed for a client to miss a status report | ||
max_status_report_interval = 300 | ||
# policy to choose which client to run the controller logic from | ||
starting_client_policy = "random" | ||
# timeout for the ClientControllerExecutor start task, which runs all of the controllers | ||
start_task_timeout = 600 | ||
} | ||
} | ||
] | ||
|
||
# List of components used in the server side workflow. | ||
components = [ | ||
] | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
description = "Client Controller FedAvg and cross-site evaluation with PyTorch" | ||
execution_api_type = "client_api" | ||
controller_type = "client" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Job Template Information Card | ||
|
||
## sag_cse_ccwf_pt | ||
name = "sag_cse_ccwf_pt" | ||
description = "Client Controller FedAvg with scatter & gather workflow and cross-site evaluation with PyTorch" | ||
class_name = "ClientControllerExecutor" | ||
controller_type = "client" | ||
executor_type = "launcher_executor" | ||
contributor = "NVIDIA" | ||
init_publish_date = "2024-04-25" | ||
last_updated_date = "2024-04-25" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
name = "sag_cse_ccwf_pt" | ||
resource_spec {} | ||
min_clients = 2 | ||
deploy_map { | ||
app = [ | ||
"@ALL" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.