Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed May 2, 2024
1 parent ccf4d1a commit 1cce018
Show file tree
Hide file tree
Showing 11 changed files with 29 additions and 13 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Job Template Information Card

## sag_cse_cc_pt
name = "sag_cse_cc_pt"
## sag_cse_ccwf_pt
name = "sag_cse_ccwf_pt"
description = "Client Controller FedAvg with scatter & gather workflow and cross-site evaluation with PyTorch"
class_name = "ClientControllerExecutor"
controller_type = "client"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name = "sag_cse_cc_pt"
name = "sag_cse_ccwf_pt"
resource_spec {}
min_clients = 2
deploy_map {
Expand Down
2 changes: 1 addition & 1 deletion nvflare/apis/controller_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
name (str): name of the task
data (Shareable): data of the task
props: Any additional properties of the task
timeout: How long this task will last. If == 0, the task never time out.
timeout: How long this task will last. If == 0, the task never time out (WFCommServer-> never time out, WFCommClient-> time out after `max_task_timeout`).
before_task_sent_cb: If provided, this callback would be called before controller sends the tasks to clients.
It needs to follow the before_task_sent_cb_signature.
after_task_sent_cb: If provided, this callback would be called after controller sends the tasks to clients.
Expand Down
13 changes: 10 additions & 3 deletions nvflare/apis/impl/wf_comm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,26 @@
from nvflare.apis.signal import Signal
from nvflare.apis.utils.task_utils import apply_filters
from nvflare.apis.wf_comm_spec import WFCommSpec
from nvflare.app_common.ccwf.common import Constant
from nvflare.private.fed.utils.fed_utils import get_target_names
from nvflare.private.privacy_manager import Scope
from nvflare.security.logging import secure_format_exception

MAX_TASK_TIMEOUT = 3600


class WFCommClient(FLComponent, WFCommSpec):
def __init__(
self,
max_task_timeout: int = Constant.MAX_TASK_TIMEOUT,
) -> None:
"""Communicator using aux channel communication.
Args:
max_task_timeout (int, optional): Maximum task timeout when `task.timeout` is set to 0. Defaults to 3600.
"""
super().__init__()
self.task_data_filters = {}
self.task_result_filters = {}
self.max_task_timeout = max_task_timeout

def broadcast(
self,
Expand Down Expand Up @@ -104,8 +110,9 @@ def broadcast_and_wait(
raise ValueError(f"The task timeout must >= 0. But got {task.timeout}")

if task.timeout == 0:
task.timeout = MAX_TASK_TIMEOUT
task.timeout = self.max_task_timeout

# Note: set request here since task.data can be modified by user callback before_task_sent_cb
request = task.data

request.set_header(ReservedKey.TASK_NAME, task.name)
Expand Down
10 changes: 7 additions & 3 deletions nvflare/app_common/ccwf/client_controller_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
task_name_prefix: str = "",
persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,
final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
max_task_timeout: int = Constant.MAX_TASK_TIMEOUT,
):
"""
ClientControllerExecutor for running controllers on client-side using WFCommClient.
Expand All @@ -47,17 +48,20 @@ def __init__(
task_name_prefix: prefix of task names. All CCWF task names are prefixed with this.
persistor_id: ID of the persistor component
final_result_ack_timeout: timeout for sending final result to participating clients
max_task_timeout: Maximum task timeout for Controllers using WFCommClient when `task.timeout` is set to 0. Defaults to 3600.
"""
check_number_range("final_result_ack_timeout", final_result_ack_timeout, min_value=1.0)

Executor.__init__(self)
self.controller_id_list = controller_id_list
self.task_name_prefix = task_name_prefix
self.persistor_id = persistor_id
self.final_result_ack_timeout = final_result_ack_timeout
self.max_task_timeout = max_task_timeout

self.start_task_name = make_task_name(task_name_prefix, Constant.BASENAME_START)
self.configure_task_name = make_task_name(task_name_prefix, Constant.BASENAME_CONFIG)
self.report_final_result_task_name = make_task_name(task_name_prefix, Constant.BASENAME_REPORT_FINAL_RESULT)
self.final_result_ack_timeout = final_result_ack_timeout
self.persistor_id = persistor_id

self.persistor = None

Expand Down Expand Up @@ -111,7 +115,7 @@ def start_run(self, fl_ctx: FLContext):
def initialize_controller(self, controller_id, fl_ctx):
controller = self.engine.get_component(controller_id)

comm = WFCommClient()
comm = WFCommClient(max_task_timeout=self.max_task_timeout)
controller.set_communicator(comm)
controller.config = self.config
controller.initialize(fl_ctx)
Expand Down
1 change: 1 addition & 0 deletions nvflare/app_common/ccwf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class Constant:
LEARN_TASK_ABORT_TIMEOUT = 5.0
FINAL_RESULT_ACK_TIMEOUT = 10
GET_MODEL_TIMEOUT = 10
MAX_TASK_TIMEOUT = 3600

PROP_KEY_TRAIN_CLIENTS = "cwf.train_clients"

Expand Down
6 changes: 3 additions & 3 deletions nvflare/app_common/ccwf/server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
Constructor
Args:
num_rounds - the number of rounds to be performed. This is a workflow config parameter.
num_rounds - the number of rounds to be performed. This is a workflow config parameter. Defaults to 1.
start_round - the starting round number. This is a workflow config parameter.
task_name_prefix - the prefix for task names of this workflow.
The workflow requires multiple tasks (e.g. config and start) between the server controller and the client.
Expand All @@ -99,7 +99,8 @@ def __init__(
starting_client - name of the starting client.
starting_client_policy - how to determine the starting client if the name is not explicitly specified.
Possible values are:
ANY - any one of the participating clients (randomly chosen)
ANY - any one of the participating clients (the first client)
RANDOM - a random client
EMPTY - no starting client
DISALLOW - does not allow implicit - starting_client must be explicitly specified
start_task_timeout - how long to wait for the starting client to finish the “start” task.
Expand Down Expand Up @@ -184,7 +185,6 @@ def start_controller(self, fl_ctx: FLContext):
allow_none=False,
)

random.shuffle(self.participating_clients)
self.log_info(fl_ctx, f"Using participating clients: {self.participating_clients}")
self.starting_client = validate_candidate(
var_name="starting_client",
Expand Down
4 changes: 4 additions & 0 deletions nvflare/fuel/utils/validation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def validate_candidates(var_name: str, candidates, base: list, default_policy: s
1. Not explicitly specified (Python object None or empty list [])
In this case, the default_policy decides the final result:
- ANY: returns a list that contains a single item from the base
- RANDOM: returns a list that contains a random item from the base
- EMPTY: returns an empty list
- ALL: returns the base list
- DISALLOW: raise exception - candidates must be explicitly specified
Expand Down Expand Up @@ -192,6 +193,8 @@ def validate_candidates(var_name: str, candidates, base: list, default_policy: s
return base
elif default_policy == DefaultValuePolicy.DISALLOW:
raise ValueError(f"invalid value '{candidates}' in '{var_name}': it must be subset of {base}")
elif default_policy == DefaultValuePolicy.RANDOM:
return [random.choice(base)]
else:
# any
return [base[0]]
Expand Down Expand Up @@ -225,6 +228,7 @@ def validate_candidate(var_name: str, candidate, base: list, default_policy: str
1. Not explicitly specified (Python object None or empty string)
In this case, the default_policy decides the final result:
- ANY: returns the first item from the base
- RANDOM: returns a random item from the base
- EMPTY: returns an empty str
- ALL or DISALLOW: raise exception - candidate must be explicitly specified
Expand Down

0 comments on commit 1cce018

Please sign in to comment.