From 7a053319d085886a00244552ffce740e35e1ae58 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Mon, 18 Mar 2024 11:15:09 -0700 Subject: [PATCH] add communicator to handlers --- nvflare/apis/impl/controller.py | 3 --- nvflare/apis/impl/wf_comm_server.py | 3 ++- .../app_common/psi/dh_psi/dh_psi_workflow.py | 25 +++++++++++++------ .../workflows/broadcast_operator.py | 10 +++++--- nvflare/app_common/workflows/cyclic_ctl.py | 1 - nvflare/private/fed/client/client_runner.py | 2 +- .../private/fed/server/server_json_config.py | 17 +++++++------ nvflare/private/fed/server/server_runner.py | 2 +- nvflare/widgets/info_collector.py | 3 ++- 9 files changed, 39 insertions(+), 27 deletions(-) diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index d994e20d44..6091ac23f1 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -145,6 +145,3 @@ def cancel_task( def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None): self.communicator.cancel_all_tasks(completion_status, fl_ctx) - - def handle_event(self, event_type: str, fl_ctx: FLContext): - self.communicator.handle_event(event_type, fl_ctx) diff --git a/nvflare/apis/impl/wf_comm_server.py b/nvflare/apis/impl/wf_comm_server.py index b1d1a1877f..661257d4df 100644 --- a/nvflare/apis/impl/wf_comm_server.py +++ b/nvflare/apis/impl/wf_comm_server.py @@ -148,8 +148,9 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): event_type (str): all event types, including AppEventType and EventType fl_ctx (FLContext): FLContext information with current event type """ - if event_type == InfoCollector.EVENT_TYPE_GET_STATS: + if event_type == InfoCollector.EVENT_TYPE_SET_STATS: self._set_stats(fl_ctx) + self.fire_event(InfoCollector.EVENT_TYPE_GET_STATS, fl_ctx) elif event_type == EventType.JOB_DEAD: client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) with self._dead_clients_lock: diff --git a/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py b/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py index c3b6b08e56..95212ff73f 100644 --- a/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py +++ b/nvflare/app_common/psi/dh_psi/dh_psi_workflow.py @@ -163,7 +163,7 @@ def pairwise_setup(self, ordered_sites: List[SiteSize]): bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.SETUP_MSG] for site_name in results} @@ -181,7 +181,7 @@ def pairwise_requests(self, ordered_sites: List[SiteSize], setup_msgs: Dict[str, bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.REQUEST_MSG] for site_name in results} @@ -199,7 +199,7 @@ def pairwise_responses(self, ordered_sites: List[SiteSize], request_msgs: Dict[s bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.RESPONSE_MSG] for site_name in results} @@ -217,7 +217,7 @@ def pairwise_intersect(self, ordered_sites: List[SiteSize], response_msg: Dict[s bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) return {site_name: results[site_name].data[PSIConst.ITEMS_SIZE] for site_name in results} @@ -279,7 +279,7 @@ def calculate_intersections(self, response_msg) -> Dict[str, int]: task_inputs[client_name] = inputs bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) intersects = {client_name: results[client_name].data[PSIConst.ITEMS_SIZE] for client_name in results} @@ -292,7 +292,11 @@ def process_requests(self, s: SiteSize, request_msgs: Dict[str, str]) -> Dict[st task_inputs[PSIConst.REQUEST_MSG_SET] = request_msgs bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.broadcast_and_wait( - task_name=self.task_name, task_input=task_inputs, targets=[s.name], abort_signal=self.abort_signal + task_name=self.task_name, + task_input=task_inputs, + fl_ctx=self.fl_ctx, + targets=[s.name], + abort_signal=self.abort_signal, ) dxo = results[s.name] @@ -309,7 +313,7 @@ def create_requests(self, site_setup_msgs) -> Dict[str, str]: bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.multicasts_and_wait( - task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal + task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal ) request_msgs = {client_name: results[client_name].data[PSIConst.REQUEST_MSG] for client_name in results} return request_msgs @@ -335,6 +339,7 @@ def prepare_sites(self, abort_signal): results = bop.broadcast_and_wait( task_name=self.task_name, task_input=inputs, + fl_ctx=self.fl_ctx, targets=targets, min_responses=min_responses, abort_signal=abort_signal, @@ -352,7 +357,11 @@ def prepare_setup_messages(self, s: SiteSize, other_site_sizes: Set[int]) -> Dic inputs[PSIConst.ITEMS_SIZE_SET] = other_site_sizes bop = BroadcastAndWait(self.fl_ctx, self.controller) results = bop.broadcast_and_wait( - task_name=self.task_name, task_input=inputs, targets=[s.name], abort_signal=self.abort_signal + task_name=self.task_name, + task_input=inputs, + fl_ctx=self.fl_ctx, + targets=[s.name], + abort_signal=self.abort_signal, ) dxo = results[s.name] return dxo.data[PSIConst.SETUP_MSG] diff --git a/nvflare/app_common/workflows/broadcast_operator.py b/nvflare/app_common/workflows/broadcast_operator.py index 07411c358e..e6473426ad 100644 --- a/nvflare/app_common/workflows/broadcast_operator.py +++ b/nvflare/app_common/workflows/broadcast_operator.py @@ -42,32 +42,34 @@ def broadcast_and_wait( self, task_name: str, task_input: Shareable, + fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None, task_props: Optional[Dict] = None, min_responses: int = 1, abort_signal: Signal = None, ) -> Dict[str, DXO]: task = Task(name=task_name, data=task_input, result_received_cb=self.results_cb, props=task_props) - self.controller.broadcast_and_wait(task, self.fl_ctx, targets, min_responses, 0, abort_signal) + self.controller.broadcast_and_wait(task, fl_ctx, targets, min_responses, 0, abort_signal) return self.results def multicasts_and_wait( self, task_name: str, task_inputs: Dict[str, Shareable], + fl_ctx: FLContext, abort_signal: Signal = None, task_check_period: int = 0.5, ) -> Dict[str, DXO]: tasks: Dict[str, Task] = self.get_tasks(task_name, task_inputs) for client_name in tasks: - self.controller.send(task=tasks[client_name], fl_ctx=self.fl_ctx, targets=[client_name]) + self.controller.send(task=tasks[client_name], fl_ctx=fl_ctx, targets=[client_name]) while self.controller.get_num_standing_tasks(): if abort_signal.triggered: - self.log_info(self.fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.") + self.log_info(fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.") return - self.log_debug(self.fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.") + self.log_debug(fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.") time.sleep(task_check_period) return self.results diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index c84ebc7cfd..22ca0e6b70 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -266,7 +266,6 @@ def restore(self, state_data: dict, fl_ctx: FLContext): pass def handle_event(self, event_type, fl_ctx): - super().handle_event(event_type, fl_ctx) if event_type == EventType.JOB_DEAD: client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME) new_client_list = [] diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index 22f475b236..b3fb0a4f75 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -672,7 +672,7 @@ def abort(self, msg: str = ""): self.run_abort_signal.trigger(True) def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == InfoCollector.EVENT_TYPE_GET_STATS: + if event_type == InfoCollector.EVENT_TYPE_SET_STATS: collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR) if collector: if not isinstance(collector, GroupInfoCollector): diff --git a/nvflare/private/fed/server/server_json_config.py b/nvflare/private/fed/server/server_json_config.py index c735e82d44..f58cb8a2f2 100644 --- a/nvflare/private/fed/server/server_json_config.py +++ b/nvflare/private/fed/server/server_json_config.py @@ -42,7 +42,6 @@ def __init__(self, id, controller: Controller): """ self.id = id self.controller = controller - self.controller.set_communicator(WFCommServer()) class ServerJsonConfigurator(FedJsonConfigurator): @@ -128,13 +127,13 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node): return if re.search(r"^workflows\.#[0-9]+$", path): - workflow = self.authorize_and_build_component(element, config_ctx, node) - if not isinstance(workflow, Controller): - raise ConfigError('"workflow" must be a Controller object, but got {}'.format(type(workflow))) + controller = self.authorize_and_build_component(element, config_ctx, node) + if not isinstance(controller, Controller): + raise ConfigError('"controller" must be a Controller object, but got {}'.format(type(controller))) cid = element.get("id", None) if not cid: - cid = type(workflow).__name__ + cid = type(controller).__name__ if not isinstance(cid, str): raise ConfigError('"id" must be str but got {}'.format(type(cid))) @@ -145,8 +144,12 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node): if cid in self.components: raise ConfigError('duplicate component id "{}"'.format(cid)) - self.workflows.append(WorkFlow(cid, workflow)) - self.components[cid] = workflow + communicator = WFCommServer() + self.handlers.append(communicator) + controller.set_communicator(communicator) + + self.workflows.append(WorkFlow(cid, controller)) + self.components[cid] = controller return def _get_all_workflows_ids(self): diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index b675220978..d13877862f 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -217,7 +217,7 @@ def run(self): self.log_info(fl_ctx, "Server runner finished.") def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == InfoCollector.EVENT_TYPE_GET_STATS: + if event_type == InfoCollector.EVENT_TYPE_SET_STATS: collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR) if collector: if not isinstance(collector, GroupInfoCollector): diff --git a/nvflare/widgets/info_collector.py b/nvflare/widgets/info_collector.py index 9f0a4d526c..0812cfac14 100644 --- a/nvflare/widgets/info_collector.py +++ b/nvflare/widgets/info_collector.py @@ -47,6 +47,7 @@ class InfoCollector(Widget): CATEGORY_STATS = "stats" CATEGORY_ERROR = "error" + EVENT_TYPE_SET_STATS = "info_collector.set_stats" EVENT_TYPE_GET_STATS = "info_collector.get_stats" CTX_KEY_STATS_COLLECTOR = "info_collector.stats_collector" @@ -137,7 +138,7 @@ def get_run_stats(self) -> dict: coll = GroupInfoCollector() fl_ctx.set_prop(key=self.CTX_KEY_STATS_COLLECTOR, value=coll, sticky=False, private=True) - engine.fire_event(event_type=self.EVENT_TYPE_GET_STATS, fl_ctx=fl_ctx) + engine.fire_event(event_type=self.EVENT_TYPE_SET_STATS, fl_ctx=fl_ctx) # Get the StatusCollector from the fl_ctx, it could have been updated by other component. coll = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR) return coll.info