Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Responder functions #2397

Merged
merged 6 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion nvflare/apis/controller_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def stop_controller(self, fl_ctx: FLContext):
"""
pass

@abstractmethod
def process_result_of_unknown_task(
self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
):
Expand Down
3 changes: 3 additions & 0 deletions nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class EventType(object):
JOB_COMPLETED = "_job_completed"
JOB_ABORTED = "_job_aborted"
JOB_CANCELLED = "_job_cancelled"
JOB_DEAD = "_job_dead"

BEFORE_PULL_TASK = "_before_pull_task"
AFTER_PULL_TASK = "_after_pull_task"
Expand All @@ -50,6 +51,8 @@ class EventType(object):
AFTER_TASK_EXECUTION = "_after_task_execution"
BEFORE_SEND_TASK_RESULT = "_before_send_task_result"
AFTER_SEND_TASK_RESULT = "_after_send_task_result"
BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK = "_before_process_result_of_unknown_task"
AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK = "_after_process_result_of_unknown_task"

CRITICAL_LOG_AVAILABLE = "_critical_log_available"
ERROR_LOG_AVAILABLE = "_error_log_available"
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class FLContextKey(object):
CLIENT_TOKEN = "__client_token"
AUTHORIZATION_RESULT = "_authorization_result"
AUTHORIZATION_REASON = "_authorization_reason"
DEAD_JOB_CLIENT_NAME = "_dead_job_client_name"

CLIENT_REGISTER_DATA = "_client_register_data"
SECURITY_ITEMS = "_security_items"
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def initialize(self, fl_ctx: FLContext):
return

self._engine = engine
self.start_controller(fl_ctx)

def set_communicator(self, communicator: WFCommSpec):
if not isinstance(communicator, WFCommSpec):
Expand Down
34 changes: 16 additions & 18 deletions nvflare/apis/impl/wf_comm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -133,8 +134,8 @@ def _set_stats(self, fl_ctx: FLContext):
raise TypeError(
"collector must be an instance of GroupInfoCollector, but got {}".format(type(collector))
)
collector.set_info(
group_name=self._name,
collector.add_info(
group_name=self.controller._name,
SYangster marked this conversation as resolved.
Show resolved Hide resolved
info={
"tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks},
},
Expand All @@ -149,6 +150,12 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
"""
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
self._set_stats(fl_ctx)
elif event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
with self._dead_clients_lock:
self.log_info(fl_ctx, f"received dead job report from client {client_name}")
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]:
"""Called by runner when a client asks for a task.
Expand Down Expand Up @@ -330,22 +337,6 @@ def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None:
self.cancel_task(task=task, fl_ctx=fl_ctx)
self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name))

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.

Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext

"""
# record the report and to be used by the task monitor
with self._dead_clients_lock:
self.log_info(fl_ctx, f"received dead job report from client {client_name}")
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

self.controller.handle_dead_job(client_name, fl_ctx)

def process_task_check(self, task_id: str, fl_ctx: FLContext):
with self._task_lock:
# task_id is the uuid associated with the client_task
Expand Down Expand Up @@ -400,7 +391,14 @@ def _do_process_submission(
if client_task is None:
# cannot find a standing task for the submission
self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id))

self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)

self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx)

self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)
return

task = client_task.task
Expand Down
10 changes: 0 additions & 10 deletions nvflare/apis/wf_comm_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,6 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext):
"""
raise NotImplementedError

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.

Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext

"""
raise NotImplementedError

def initialize_run(self, fl_ctx: FLContext):
"""Called when a new RUN is about to start.

Expand Down
25 changes: 17 additions & 8 deletions nvflare/app_common/psi/dh_psi/dh_psi_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def pairwise_setup(self, ordered_sites: List[SiteSize]):

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
YuanTingHsieh marked this conversation as resolved.
Show resolved Hide resolved
)
return {site_name: results[site_name].data[PSIConst.SETUP_MSG] for site_name in results}

Expand All @@ -181,7 +181,7 @@ def pairwise_requests(self, ordered_sites: List[SiteSize], setup_msgs: Dict[str,

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.REQUEST_MSG] for site_name in results}

Expand All @@ -199,7 +199,7 @@ def pairwise_responses(self, ordered_sites: List[SiteSize], request_msgs: Dict[s

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.RESPONSE_MSG] for site_name in results}

Expand All @@ -217,7 +217,7 @@ def pairwise_intersect(self, ordered_sites: List[SiteSize], response_msg: Dict[s

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
return {site_name: results[site_name].data[PSIConst.ITEMS_SIZE] for site_name in results}

Expand Down Expand Up @@ -279,7 +279,7 @@ def calculate_intersections(self, response_msg) -> Dict[str, int]:
task_inputs[client_name] = inputs
bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)

intersects = {client_name: results[client_name].data[PSIConst.ITEMS_SIZE] for client_name in results}
Expand All @@ -292,7 +292,11 @@ def process_requests(self, s: SiteSize, request_msgs: Dict[str, str]) -> Dict[st
task_inputs[PSIConst.REQUEST_MSG_SET] = request_msgs
bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.broadcast_and_wait(
task_name=self.task_name, task_input=task_inputs, targets=[s.name], abort_signal=self.abort_signal
task_name=self.task_name,
task_input=task_inputs,
fl_ctx=self.fl_ctx,
targets=[s.name],
abort_signal=self.abort_signal,
)

dxo = results[s.name]
Expand All @@ -309,7 +313,7 @@ def create_requests(self, site_setup_msgs) -> Dict[str, str]:

bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.multicasts_and_wait(
task_name=self.task_name, task_inputs=task_inputs, abort_signal=self.abort_signal
task_name=self.task_name, task_inputs=task_inputs, fl_ctx=self.fl_ctx, abort_signal=self.abort_signal
)
request_msgs = {client_name: results[client_name].data[PSIConst.REQUEST_MSG] for client_name in results}
return request_msgs
Expand All @@ -335,6 +339,7 @@ def prepare_sites(self, abort_signal):
results = bop.broadcast_and_wait(
task_name=self.task_name,
task_input=inputs,
fl_ctx=self.fl_ctx,
targets=targets,
min_responses=min_responses,
abort_signal=abort_signal,
Expand All @@ -352,7 +357,11 @@ def prepare_setup_messages(self, s: SiteSize, other_site_sizes: Set[int]) -> Dic
inputs[PSIConst.ITEMS_SIZE_SET] = other_site_sizes
bop = BroadcastAndWait(self.fl_ctx, self.controller)
results = bop.broadcast_and_wait(
task_name=self.task_name, task_input=inputs, targets=[s.name], abort_signal=self.abort_signal
task_name=self.task_name,
task_input=inputs,
fl_ctx=self.fl_ctx,
targets=[s.name],
abort_signal=self.abort_signal,
)
dxo = results[s.name]
return dxo.data[PSIConst.SETUP_MSG]
17 changes: 12 additions & 5 deletions nvflare/app_common/workflows/broadcast_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import threading
import time
from typing import Dict, List, Optional, Union

from nvflare.apis.client import Client
Expand Down Expand Up @@ -41,29 +42,35 @@ def broadcast_and_wait(
self,
task_name: str,
task_input: Shareable,
fl_ctx: FLContext,
targets: Union[List[Client], List[str], None] = None,
task_props: Optional[Dict] = None,
min_responses: int = 1,
abort_signal: Signal = None,
) -> Dict[str, DXO]:
task = Task(name=task_name, data=task_input, result_received_cb=self.results_cb, props=task_props)
self.controller.broadcast_and_wait(task, self.fl_ctx, targets, min_responses, 0, abort_signal)
self.controller.broadcast_and_wait(task, fl_ctx, targets, min_responses, 0, abort_signal)
return self.results

def multicasts_and_wait(
self,
task_name: str,
task_inputs: Dict[str, Shareable],
fl_ctx: FLContext,
abort_signal: Signal = None,
task_check_period: int = 0.5,
) -> Dict[str, DXO]:

tasks: Dict[str, Task] = self.get_tasks(task_name, task_inputs)
for client_name in tasks:
self.controller.broadcast(task=tasks[client_name], fl_ctx=self.fl_ctx, targets=[client_name])
self.controller.send(task=tasks[client_name], fl_ctx=fl_ctx, targets=[client_name])

for client_name in tasks:
self.log_info(self.fl_ctx, f"wait for client {client_name} task")
self.controller.wait_for_task(tasks[client_name], abort_signal)
while self.controller.get_num_standing_tasks():
if abort_signal.triggered:
self.log_info(fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.")
return
self.log_debug(fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.")
time.sleep(task_check_period)

return self.results

Expand Down
19 changes: 10 additions & 9 deletions nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
Expand Down Expand Up @@ -264,11 +265,11 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
finally:
pass

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
super().handle_dead_job(client_name, fl_ctx)

new_client_list = []
for client in self._participating_clients:
if client_name != client.name:
new_client_list.append(client)
self._participating_clients = new_client_list
def handle_event(self, event_type, fl_ctx):
if event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
new_client_list = []
for client in self._participating_clients:
if client_name != client.name:
new_client_list.append(client)
self._participating_clients = new_client_list
SYangster marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def start_controller(self, fl_ctx: FLContext) -> None:
self.event(AppEventType.INITIAL_MODEL_LOADED)
YuanTingHsieh marked this conversation as resolved.
Show resolved Hide resolved

self.engine = self.fl_ctx.get_engine()
self.initialize()
FLComponentWrapper.initialize(self)

def _build_shareable(self, data: FLModel = None) -> Shareable:
if not data: # if no data is given, send self.model
Expand Down
17 changes: 10 additions & 7 deletions nvflare/private/fed/server/server_json_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __init__(self, id, controller: Controller):
"""
self.id = id
self.controller = controller
self.controller.set_communicator(WFCommServer())


class ServerJsonConfigurator(FedJsonConfigurator):
Expand Down Expand Up @@ -128,13 +127,13 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
return

if re.search(r"^workflows\.#[0-9]+$", path):
workflow = self.authorize_and_build_component(element, config_ctx, node)
if not isinstance(workflow, Controller):
raise ConfigError('"workflow" must be a Controller object, but got {}'.format(type(workflow)))
controller = self.authorize_and_build_component(element, config_ctx, node)
if not isinstance(controller, Controller):
raise ConfigError('"controller" must be a Controller object, but got {}'.format(type(controller)))

cid = element.get("id", None)
if not cid:
cid = type(workflow).__name__
cid = type(controller).__name__

if not isinstance(cid, str):
raise ConfigError('"id" must be str but got {}'.format(type(cid)))
Expand All @@ -145,8 +144,12 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
if cid in self.components:
raise ConfigError('duplicate component id "{}"'.format(cid))

self.workflows.append(WorkFlow(cid, workflow))
self.components[cid] = workflow
communicator = WFCommServer()
self.handlers.append(communicator)
controller.set_communicator(communicator)

self.workflows.append(WorkFlow(cid, controller))
self.components[cid] = controller
return

def _get_all_workflows_ids(self):
Expand Down
8 changes: 5 additions & 3 deletions nvflare/private/fed/server/server_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ def _execute_run(self):

fl_ctx.set_prop(FLContextKey.WORKFLOW, wf.id, sticky=True)

wf.controller.initialize(fl_ctx)
wf.controller.communicator.initialize_run(fl_ctx)
wf.controller.start_controller(fl_ctx)
wf.controller.initialize(fl_ctx)

self.log_info(fl_ctx, "Workflow {} ({}) started".format(wf.id, type(wf.controller)))
self.log_debug(fl_ctx, "firing event EventType.START_WORKFLOW")
Expand Down Expand Up @@ -381,7 +380,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
if self.current_wf is None:
return

self.current_wf.controller.communicator.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx)
fl_ctx.set_prop(FLContextKey.DEAD_JOB_CLIENT_NAME, client_name)
self.log_debug(fl_ctx, "firing event EventType.JOB_DEAD")
self.fire_event(EventType.JOB_DEAD, fl_ctx)

except Exception as e:
self.log_exception(
fl_ctx, f"Error processing dead job by workflow {self.current_wf.id}: {secure_format_exception(e)}"
Expand Down
Loading