Skip to content

Commit

Permalink
remove wait_for_task, handle_dead_job from controller, add some alter…
Browse files Browse the repository at this point in the history
…native events
  • Loading branch information
SYangster committed Mar 15, 2024
1 parent 529985c commit 82cd200
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 71 deletions.
18 changes: 0 additions & 18 deletions nvflare/apis/controller_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def stop_controller(self, fl_ctx: FLContext):
"""
pass

@abstractmethod
def process_result_of_unknown_task(
self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
):
Expand Down Expand Up @@ -543,20 +542,3 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_
fl_ctx: the FL context
"""
pass

def wait_for_task(self, task: Task, abort_signal: Signal):
"""Wait for task to complete or abort.
Args:
task: the task to wait for
abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller.
"""
pass

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext
"""
pass
3 changes: 3 additions & 0 deletions nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class EventType(object):
JOB_COMPLETED = "_job_completed"
JOB_ABORTED = "_job_aborted"
JOB_CANCELLED = "_job_cancelled"
JOB_DEAD = "_job_dead"

BEFORE_PULL_TASK = "_before_pull_task"
AFTER_PULL_TASK = "_after_pull_task"
Expand All @@ -50,6 +51,8 @@ class EventType(object):
AFTER_TASK_EXECUTION = "_after_task_execution"
BEFORE_SEND_TASK_RESULT = "_before_send_task_result"
AFTER_SEND_TASK_RESULT = "_after_send_task_result"
BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK = "_before_process_result_of_unknown_task"
AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK = "_after_process_result_of_unknown_task"

CRITICAL_LOG_AVAILABLE = "_critical_log_available"
ERROR_LOG_AVAILABLE = "_error_log_available"
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class FLContextKey(object):
CLIENT_TOKEN = "__client_token"
AUTHORIZATION_RESULT = "_authorization_result"
AUTHORIZATION_REASON = "_authorization_reason"
DEAD_JOB_CLIENT_NAME = "_dead_job_client_name"

CLIENT_REGISTER_DATA = "_client_register_data"
SECURITY_ITEMS = "_security_items"
Expand Down
6 changes: 0 additions & 6 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,5 @@ def cancel_task(
def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
self.communicator.cancel_all_tasks(completion_status, fl_ctx)

def wait_for_task(self, task: Task, abort_signal: Signal):
self.communicator.wait_for_task(task, abort_signal)

def handle_event(self, event_type: str, fl_ctx: FLContext):
self.communicator.handle_event(event_type, fl_ctx)

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
self.communicator.handle_dead_job(client_name, fl_ctx)
28 changes: 14 additions & 14 deletions nvflare/apis/impl/wf_comm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -149,6 +150,12 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
"""
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
self._set_stats(fl_ctx)
elif event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
with self._dead_clients_lock:
self.log_info(fl_ctx, f"received dead job report from client {client_name}")
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]:
"""Called by runner when a client asks for a task.
Expand Down Expand Up @@ -330,20 +337,6 @@ def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None:
self.cancel_task(task=task, fl_ctx=fl_ctx)
self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name))

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext
"""
# record the report and to be used by the task monitor
with self._dead_clients_lock:
self.log_info(fl_ctx, f"received dead job report from client {client_name}")
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

def process_task_check(self, task_id: str, fl_ctx: FLContext):
with self._task_lock:
# task_id is the uuid associated with the client_task
Expand Down Expand Up @@ -398,7 +391,14 @@ def _do_process_submission(
if client_task is None:
# cannot find a standing task for the submission
self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id))

self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)

self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx)

self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)
return

task = client_task.task
Expand Down
19 changes: 0 additions & 19 deletions nvflare/apis/wf_comm_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,6 @@ def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_
"""
raise NotImplementedError

def wait_for_task(self, task: Task, abort_signal: Signal):
"""Wait for task to complete or abort.
Args:
task: the task to wait for
abort_signal: the abort signal. If triggered, this method stops waiting and returns to the caller.
"""
raise NotImplementedError

def check_tasks(self):
"""Checks if tasks should be exited."""
raise NotImplementedError
Expand Down Expand Up @@ -302,16 +293,6 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext):
"""
raise NotImplementedError

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext
"""
raise NotImplementedError

def initialize_run(self, fl_ctx: FLContext):
"""Called when a new RUN is about to start.
Expand Down
13 changes: 9 additions & 4 deletions nvflare/app_common/workflows/broadcast_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import threading
import time
from typing import Dict, List, Optional, Union

from nvflare.apis.client import Client
Expand Down Expand Up @@ -55,15 +56,19 @@ def multicasts_and_wait(
task_name: str,
task_inputs: Dict[str, Shareable],
abort_signal: Signal = None,
task_check_period: int = 0.5,
) -> Dict[str, DXO]:

tasks: Dict[str, Task] = self.get_tasks(task_name, task_inputs)
for client_name in tasks:
self.controller.broadcast(task=tasks[client_name], fl_ctx=self.fl_ctx, targets=[client_name])
self.controller.send(task=tasks[client_name], fl_ctx=self.fl_ctx, targets=[client_name])

for client_name in tasks:
self.log_info(self.fl_ctx, f"wait for client {client_name} task")
self.controller.wait_for_task(tasks[client_name], abort_signal)
while self.controller.get_num_standing_tasks():
if abort_signal.triggered:
self.log_info(self.fl_ctx, "Abort signal triggered. Finishing multicasts_and_wait.")
return
self.log_debug(self.fl_ctx, "Checking standing tasks to see if multicasts_and_wait finished.")
time.sleep(task_check_period)

return self.results

Expand Down
19 changes: 10 additions & 9 deletions nvflare/app_common/workflows/cyclic_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
Expand Down Expand Up @@ -264,11 +265,11 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
finally:
pass

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
super().handle_dead_job(client_name, fl_ctx)

new_client_list = []
for client in self._participating_clients:
if client_name != client.name:
new_client_list.append(client)
self._participating_clients = new_client_list
def handle_event(self, event_type, fl_ctx):
if event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
new_client_list = []
for client in self._participating_clients:
if client_name != client.name:
new_client_list.append(client)
self._participating_clients = new_client_list
5 changes: 4 additions & 1 deletion nvflare/private/fed/server/server_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
if self.current_wf is None:
return

self.current_wf.controller.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx)
fl_ctx.set_prop(FLContextKey.DEAD_JOB_CLIENT_NAME, client_name)
self.log_debug(fl_ctx, "firing event EventType.JOB_DEAD")
self.fire_event(EventType.JOB_DEAD, fl_ctx)

except Exception as e:
self.log_exception(
fl_ctx, f"Error processing dead job by workflow {self.current_wf.id}: {secure_format_exception(e)}"
Expand Down

0 comments on commit 82cd200

Please sign in to comment.