Skip to content

Commit

Permalink
add communicator to handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed Mar 18, 2024
1 parent 1cb0079 commit 7a05331
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 27 deletions.
3 changes: 0 additions & 3 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,3 @@ def cancel_task(

def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
self.communicator.cancel_all_tasks(completion_status, fl_ctx)

def handle_event(self, event_type: str, fl_ctx: FLContext):
self.communicator.handle_event(event_type, fl_ctx)
3 changes: 2 additions & 1 deletion nvflare/apis/impl/wf_comm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
event_type (str): all event types, including AppEventType and EventType
fl_ctx (FLContext): FLContext information with current event type
"""
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
if event_type == InfoCollector.EVENT_TYPE_SET_STATS:
self._set_stats(fl_ctx)
self.fire_event(InfoCollector.EVENT_TYPE_GET_STATS, fl_ctx)
elif event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
with self._dead_clients_lock:
Expand Down
25 changes: 17 additions & 8 deletions nvflare/app_common/psi/dh_psi/dh_psi_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def pairwise_setup(self, ordered_sites: List[SiteSize]):

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.SETUP_MSG] for site_name in results}

Expand All @@ -181,7 +181,7 @@ def pairwise_requests(self, ordered_sites: List[SiteSize], setup_msgs: Dict[str,

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.REQUEST_MSG] for site_name in results}

Expand All @@ -199,7 +199,7 @@ def pairwise_responses(self, ordered_sites: List[SiteSize], request_msgs: Dict[s

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.RESPONSE_MSG] for site_name in results}

Expand All @@ -217,7 +217,7 @@ def pairwise_intersect(self, ordered_sites: List[SiteSize], response_msg: Dict[s

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.ITEMS_SIZE] for site_name in results}

Expand Down Expand Up @@ -279,7 +279,7 @@ def calculate_intersections(self, response_msg) -> Dict[str, int]:
task_inputs[client_name] = inputs
bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)

intersects = {client_name: results[client_name].data[PSIConst.ITEMS_SIZE] for client_name in results}
Expand All @@ -292,7 +292,11 @@ def process_requests(self, s: SiteSize, request_msgs: Dict[str, str]) -> Dict[st
task_inputs[PSIConst.REQUEST_MSG_SET] = request_msgs
bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.broadcast_and_wait(
task_name=self.task_name, task_input=task_inputs, targets=[s.name], abort_signal=self.abort_signal
task_name=self.task_name,
task_input=task_inputs,
fl_ctx=self.fl_ctx,
targets=[s.name],
abort_signal=self.abort_signal,
)

dxo = results[s.name]
Expand All @@ -309,7 +313,7 @@ def create_requests(self, site_setup_msgs) -> Dict[str, str]:

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
request_msgs = {client_name: results[client_name].data[PSIConst.REQUEST_MSG] for client_name in results}
return request_msgs
Expand All @@ -335,6 +339,7 @@ def prepare_sites(self, abort_signal):
results = bop.broadcast_and_wait(
task_name=self.task_name,
task_input=inputs,
fl_ctx=self.fl_ctx,
targets=targets,
min_responses=min_responses,
abort_signal=abort_signal,
Expand All @@ -352,7 +357,11 @@ def prepare_setup_messages(self, s: SiteSize, other_site_sizes: Set[int]) -> Dic
inputs[PSIConst.ITEMS_SIZE_SET] = other_site_sizes
bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.broadcast_and_wait(
task_name=self.task_name, task_input=inputs, targets=[s.name], abort_signal=self.abort_signal
task_name=self.task_name,
task_input=inputs,
fl_ctx=self.fl_ctx,
targets=[s.name],
abort_signal=self.abort_signal,
)
dxo = results[s.name]
return dxo.data[PSIConst.SETUP_MSG]
10 changes: 6 additions & 4 deletions nvflare/app_common/workflows/broadcast_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,32 +42,34 @@ def broadcast_and_wait(
self,
task_name: str,
task_input: Shareable,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
task_props: Optional[Dict] = None,
min_responses: int = 1,
abort_signal: Signal = None,
) -> Dict[str, DXO]:
task = Task(name=task_name, data=task_input, result_received_cb=self.results_cb, props=task_props)
self.controller.broadcast_and_wait(task, self.fl_ctx, targets, min_responses, 0, abort_signal)
self.controller.broadcast_and_wait(task, fl_ctx, targets, min_responses, 0, abort_signal)
return self.results

def multicasts_and_wait(
self,
task_name: str,
task_inputs: Dict[str, Shareable],
fl_ctx: FLContext,
abort_signal: Signal = None,
task_check_period: int = 0.5,
) -> Dict[str, DXO]:

tasks: Dict[str, Task] = self.get_tasks(task_name, task_inputs)
for client_name in tasks:
self.controller.send(task=tasks[client_name], fl_ctx=self.fl_ctx, targets=[client_name])
self.controller.send(task=tasks[client_name], fl_ctx=fl_ctx, targets=[client_name])

while self.controller.get_num_standing_tasks():
if abort_signal.triggered:
self.log_info(self.fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.")
self.log_info(fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.")
return
self.log_debug(self.fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.")
self.log_debug(fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.")
time.sleep(task_check_period)

return self.results
Expand Down
1 change: 0 additions & 1 deletion nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
pass

def handle_event(self, event_type, fl_ctx):
super().handle_event(event_type, fl_ctx)
if event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
new_client_list = []
Expand Down
2 changes: 1 addition & 1 deletion nvflare/private/fed/client/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def abort(self, msg: str = ""):
self.run_abort_signal.trigger(True)

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
if event_type == InfoCollector.EVENT_TYPE_SET_STATS:
collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
if collector:
if not isinstance(collector, GroupInfoCollector):
Expand Down
17 changes: 10 additions & 7 deletions nvflare/private/fed/server/server_json_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __init__(self, id, controller: Controller):
"""
self.id = id
self.controller = controller
self.controller.set_communicator(WFCommServer())


class ServerJsonConfigurator(FedJsonConfigurator):
Expand Down Expand Up @@ -128,13 +127,13 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
return

if re.search(r"^workflows\.#[0-9]+$", path):
workflow = self.authorize_and_build_component(element, config_ctx, node)
if not isinstance(workflow, Controller):
raise ConfigError('"workflow" must be a Controller object, but got {}'.format(type(workflow)))
controller = self.authorize_and_build_component(element, config_ctx, node)
if not isinstance(controller, Controller):
raise ConfigError('"controller" must be a Controller object, but got {}'.format(type(controller)))

cid = element.get("id", None)
if not cid:
cid = type(workflow).__name__
cid = type(controller).__name__

if not isinstance(cid, str):
raise ConfigError('"id" must be str but got {}'.format(type(cid)))
Expand All @@ -145,8 +144,12 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
if cid in self.components:
raise ConfigError('duplicate component id "{}"'.format(cid))

self.workflows.append(WorkFlow(cid, workflow))
self.components[cid] = workflow
communicator = WFCommServer()
self.handlers.append(communicator)
controller.set_communicator(communicator)

self.workflows.append(WorkFlow(cid, controller))
self.components[cid] = controller
return

def _get_all_workflows_ids(self):
Expand Down
2 changes: 1 addition & 1 deletion nvflare/private/fed/server/server_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def run(self):
self.log_info(fl_ctx, "Server runner finished.")

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
if event_type == InfoCollector.EVENT_TYPE_SET_STATS:
collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
if collector:
if not isinstance(collector, GroupInfoCollector):
Expand Down
3 changes: 2 additions & 1 deletion nvflare/widgets/info_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class InfoCollector(Widget):
CATEGORY_STATS = "stats"
CATEGORY_ERROR = "error"

EVENT_TYPE_SET_STATS = "info_collector.set_stats"
EVENT_TYPE_GET_STATS = "info_collector.get_stats"
CTX_KEY_STATS_COLLECTOR = "info_collector.stats_collector"

Expand Down Expand Up @@ -137,7 +138,7 @@ def get_run_stats(self) -> dict:
coll = GroupInfoCollector()
fl_ctx.set_prop(key=self.CTX_KEY_STATS_COLLECTOR, value=coll, sticky=False, private=True)

engine.fire_event(event_type=self.EVENT_TYPE_GET_STATS, fl_ctx=fl_ctx)
engine.fire_event(event_type=self.EVENT_TYPE_SET_STATS, fl_ctx=fl_ctx)
# Get the StatusCollector from the fl_ctx, it could have been updated by other component.
coll = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR)
return coll.info
Expand Down

0 comments on commit 7a05331

Please sign in to comment.