Add client controller executor #2530

Merged: 6 commits, May 7, 2024
183 changes: 183 additions & 0 deletions job_templates/sag_cse_ccwf_pt/config_fed_client.conf
@@ -0,0 +1,183 @@
{
# version of the configuration
format_version = 2

# This is the application script that will be invoked. Clients can replace this script with their own training script.
app_script = "train.py"

# Additional arguments needed by the training code. For example, with PyTorch Lightning these can be --trainer.batch_size=xxx.
app_config = ""

# Path to defined PyTorch network
# This assumes that there will be a "net.py" file with class name "Net"; please modify accordingly
model_class_path = "net.Net"

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = [
"train",
"validate",
"submit_model"
]

# This particular executor
executor {

# This is an executor for PyTorch + Client API. The underlying data exchange uses a Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

args {

# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "pytorch"

# if the transfer_type is FULL, the parameters are sent directly
# if the transfer_type is DIFF, the difference vs. the received
# parameters is calculated and that difference is sent
params_transfer_type = "FULL"
# if train_with_evaluation is true, the executor expects the custom code
# to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = true

train_task_name = "train"
evaluate_task_name = "validate"
submit_model_task_name = "submit_model"
}
}
}
{
# All tasks prefixed with wf_ are routed to this ClientControllerExecutor
tasks = ["wf_*"]
executor {
id = "client_controller_executor"
path = "nvflare.app_common.ccwf.client_controller_executor.ClientControllerExecutor"
# ClientControllerExecutor for running controllers on the client side.
args {
# list of controller ids from components to be run in order
controller_id_list = ["sag_ctl", "cse_ctl"]
task_name_prefix = "wf"
# persistor used to distribute and save final results for clients
persistor_id = "persistor"
}
}
}
]

# Array of task data filters. If provided, they control the data sent from the client controller to the client executor
# Filter direction (in, out, inout) can be set; since clients send tasks to each other, a task has both a sending (out) and a receiving (in) direction
task_data_filters = []

# Array of task result filters. If provided, they control the data sent from the client executor to the client controller
# Filter direction (in, out, inout) can be set; since clients send tasks to each other, a task has both a sending (out) and a receiving (in) direction
task_result_filters = []

components = [
{
id = "sag_ctl"
path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
args {
min_clients = 2
num_rounds = 3
start_round = 0
wait_time_after_min_received = 0
aggregator_id = "aggregator"
persistor_id = "persistor"
shareable_generator_id = "shareable_generator"
train_task_name = "train"
train_timeout = 0
}
}
{
id = "cse_ctl",
path = "nvflare.app_common.workflows.cross_site_model_eval.CrossSiteModelEval",
args {
model_locator_id = "model_locator",
submit_model_timeout = 600,
validation_timeout = 6000,
cleanup_models = false
}
}
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from the server
launch_once = true
}
}
{
id = "pipe"

path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"

args {
# Mode of the endpoint. A pipe has two endpoints.
# An endpoint can be either the one that initiates communication or the one listening.
# PASSIVE is the one listening.
mode = "PASSIVE"

# root_path is the directory location of the parameter exchange.
# You can also set it to an absolute path in your system.
root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
}
}
# required components for the client-controlled workflow defined on the client side
{
id = "persistor"
path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"
args.model.path = "{model_class_path}"
}
{
id = "shareable_generator"
path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"
args = {}
}
{
# This is the aggregator that performs the weighted average aggregation.
# The aggregation is "in-time": it doesn't wait for all client results, but aggregates as soon as data is received.
id = "aggregator"
path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
args.expected_data_kind = "WEIGHTS"
},
{
id = "model_locator"
name = "PTFileModelLocator"
args {
pt_persistor_id = "persistor"
}
},
{
# This component is not directly used in the workflow.
# It selects the best model based on the incoming global validation metrics.
id = "model_selector"
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
# make sure this "key_metric" matches what the server side receives
args.key_metric = "accuracy"
},
{
id = "json_generator"
name = "ValidationJsonGenerator"
args {}
}
]
}
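
The executor above launches the `app_script` (train.py) through the SubprocessLauncher and exchanges model parameters with it over the FilePipe. A minimal sketch of such a training script using the NVFLARE Client API is shown below; the model, the local training loop, and the `evaluate` helper are placeholders and not part of this PR.

```python
# Minimal sketch of a Client API training script (train.py).
# The local training loop and evaluate() are placeholders.
import nvflare.client as flare
from net import Net  # the "net.Net" class referenced by model_class_path


def evaluate(model) -> float:
    # placeholder for local validation; return the metric named in key_metric
    return 0.0


def main():
    flare.init()  # connect to the pipe set up by the launcher executor
    model = Net()

    while flare.is_running():
        input_model = flare.receive()              # FLModel holding the global weights
        model.load_state_dict(input_model.params)  # params_exchange_format = "pytorch"

        # ... local training on the client's data would go here ...

        accuracy = evaluate(model)
        # with train_with_evaluation = true, send back both the trained
        # parameters and the evaluation metric
        flare.send(flare.FLModel(params=model.state_dict(), metrics={"accuracy": accuracy}))


if __name__ == "__main__":
    main()
```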
39 changes: 39 additions & 0 deletions job_templates/sag_cse_ccwf_pt/config_fed_server.conf
@@ -0,0 +1,39 @@
{
# version of the configuration
format_version = 2

# task data filters: if provided, the filters will filter the data flowing from the server to the clients.
task_data_filters =[]

# task result filters: if provided, the filters will filter the results flowing from the clients to the server.
task_result_filters = []

# This assumes that there will be a "net.py" file with class name "Net".
# If your model code is not in "net.py" or the class name is not "Net", please modify here
model_class_path = "net.Net"

# workflows: Array of workflows that control the Federated Learning job lifecycle.
# One can specify multiple workflows; NVFLARE will run them in the order specified.
workflows = [
{
# server-side controller to manage job life cycle and configuration
id = "svr_ctl"
path = "nvflare.app_common.ccwf.server_ctl.ServerSideController"
args {
# the prefix for task names of this workflow
task_name_prefix = "wf"
# the maximum amount of time allowed for a client to miss a status report
max_status_report_interval = 300
# policy to choose which client to run the controller logic from
starting_client_policy = "random"
# timeout for the ClientControllerExecutor start task, which runs all of the controllers
start_task_timeout = 600
}
}
]

# List of components used in the server side workflow.
components = [
]

}
5 changes: 5 additions & 0 deletions job_templates/sag_cse_ccwf_pt/info.conf
@@ -0,0 +1,5 @@
{
description = "Client Controller FedAvg and cross-site evaluation with PyTorch"
execution_api_type = "client_api"
controller_type = "client"
}
11 changes: 11 additions & 0 deletions job_templates/sag_cse_ccwf_pt/info.md
@@ -0,0 +1,11 @@
# Job Template Information Card

## sag_cse_ccwf_pt
name = "sag_cse_ccwf_pt"
description = "Client Controller FedAvg with scatter & gather workflow and cross-site evaluation with PyTorch"
class_name = "ClientControllerExecutor"
controller_type = "client"
executor_type = "launcher_executor"
contributor = "NVIDIA"
init_publish_date = "2024-04-25"
last_updated_date = "2024-04-25"
8 changes: 8 additions & 0 deletions job_templates/sag_cse_ccwf_pt/meta.conf
@@ -0,0 +1,8 @@
name = "sag_cse_ccwf_pt"
resource_spec {}
min_clients = 2
deploy_map {
app = [
"@ALL"
]
}
2 changes: 1 addition & 1 deletion nvflare/apis/controller_spec.py
@@ -88,7 +88,7 @@ def __init__(
name (str): name of the task
data (Shareable): data of the task
props: Any additional properties of the task
timeout: How long this task will last. If == 0, the task never times out.
timeout: How long this task will last. If == 0, the task never times out (WFCommServer -> never times out, WFCommClient -> times out after `max_task_timeout`).
before_task_sent_cb: If provided, this callback would be called before controller sends the tasks to clients.
It needs to follow the before_task_sent_cb_signature.
after_task_sent_cb: If provided, this callback would be called after controller sends the tasks to clients.
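
The docstring change above distinguishes how a timeout of 0 is handled by the two communicator types: the server-side communicator keeps the task open indefinitely, while the client-side communicator falls back to `max_task_timeout`. A small sketch of creating such a task (the task name and empty payload are illustrative only):

```python
# Sketch only: a Task with timeout=0 never times out under WFCommServer,
# but times out after max_task_timeout under WFCommClient.
from nvflare.apis.controller_spec import Task
from nvflare.apis.shareable import Shareable

task = Task(name="train", data=Shareable(), timeout=0)  # illustrative task
```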
32 changes: 22 additions & 10 deletions nvflare/apis/impl/controller.py
@@ -60,7 +60,7 @@ def broadcast(
min_responses: int = 1,
wait_time_after_min_received: int = 0,
):
self.communicator.broadcast(task, fl_ctx, targets, min_responses, wait_time_after_min_received)
return self.communicator.broadcast(task, fl_ctx, targets, min_responses, wait_time_after_min_received)

def broadcast_and_wait(
self,
@@ -71,12 +71,12 @@ def broadcast_and_wait(
wait_time_after_min_received: int = 0,
abort_signal: Optional[Signal] = None,
):
self.communicator.broadcast_and_wait(
return self.communicator.broadcast_and_wait(
task, fl_ctx, targets, min_responses, wait_time_after_min_received, abort_signal
)

def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None):
self.communicator.broadcast_forever(task, fl_ctx, targets)
return self.communicator.broadcast_forever(task, fl_ctx, targets)

def send(
self,
@@ -86,7 +86,7 @@ def send(
send_order: SendOrder = SendOrder.SEQUENTIAL,
task_assignment_timeout: int = 0,
):
self.communicator.send(task, fl_ctx, targets, send_order, task_assignment_timeout)
return self.communicator.send(task, fl_ctx, targets, send_order, task_assignment_timeout)

def send_and_wait(
self,
@@ -97,7 +97,7 @@ def send_and_wait(
task_assignment_timeout: int = 0,
abort_signal: Signal = None,
):
self.communicator.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout, abort_signal)
return self.communicator.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout, abort_signal)

def relay(
self,
@@ -109,7 +109,7 @@ def relay(
task_result_timeout: int = 0,
dynamic_targets: bool = True,
):
self.communicator.relay(
return self.communicator.relay(
task, fl_ctx, targets, send_order, task_assignment_timeout, task_result_timeout, dynamic_targets
)

@@ -124,7 +124,7 @@ def relay_and_wait(
dynamic_targets: bool = True,
abort_signal: Optional[Signal] = None,
):
self.communicator.relay_and_wait(
return self.communicator.relay_and_wait(
task,
fl_ctx,
targets,
@@ -136,15 +136,22 @@ def relay_and_wait(
)

def get_num_standing_tasks(self) -> int:
return self.communicator.get_num_standing_tasks()
try:
return self.communicator.get_num_standing_tasks()
except Exception as e:
self.logger.warning(f"get_num_standing_tasks() is not supported by {self.communicator}: {e}")
return None

def cancel_task(
self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None
):
self.communicator.cancel_task(task, completion_status, fl_ctx)

def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
try:
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
except Exception as e:
self.log_warning(fl_ctx, f"cancel_all_tasks() is not supported by {self.communicator}: {e}")

def get_client_disconnect_time(self, client_name):
"""Get the time when the client is deemed disconnected.
@@ -157,4 +164,9 @@ def get_client_disconnect_time(self, client_name):
"""
if not self.communicator:
return None
return self.communicator.get_client_disconnect_time(client_name)

try:
return self.communicator.get_client_disconnect_time(client_name)
except Exception as e:
self.logger.warning(f"get_client_disconnect_time() is not supported by {self.communicator}: {e}")
return None
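
The controller methods above now return the underlying communicator's result and guard calls that a given communicator may not implement (for example, a client-side WFCommClient). A sketch of that defensive pattern in isolation, with an illustrative communicator object and logger name:

```python
# Sketch of the defensive call-through pattern introduced above: delegate to
# the communicator, but degrade gracefully when a method is unsupported.
# The communicator argument and logger name are illustrative only.
import logging

logger = logging.getLogger("controller_sketch")


def get_num_standing_tasks(communicator):
    try:
        return communicator.get_num_standing_tasks()
    except Exception as e:
        logger.warning(f"get_num_standing_tasks() is not supported by {communicator}: {e}")
        return None
```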