diff --git a/nvflare/app_common/ccwf/ccwf_job.py b/nvflare/app_common/ccwf/ccwf_job.py
index 4c0060092b..13631ab1d3 100644
--- a/nvflare/app_common/ccwf/ccwf_job.py
+++ b/nvflare/app_common/ccwf/ccwf_job.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from nvflare.apis.executor import Executor
 from nvflare.app_common.abstract.aggregator import Aggregator
@@ -21,7 +21,7 @@
 from nvflare.app_common.app_constant import AppConstants
 from nvflare.app_common.ccwf.common import Constant, CyclicOrder
 from nvflare.fuel.utils.validation_utils import check_object_type
-from nvflare.job_config.api import FedJob
+from nvflare.job_config.api import FedJob, has_add_to_job_method
 from nvflare.widgets.widget import Widget
 
 from .cse_client_ctl import CrossSiteEvalClientController
@@ -65,12 +65,12 @@ def __init__(
 class SwarmClientConfig:
     def __init__(
         self,
-        executor: Executor,
-        persistor: ModelPersistor,
-        shareable_generator: ShareableGenerator,
-        aggregator: Aggregator,
-        metric_comparator: MetricComparator = None,
-        model_selector: Widget = None,
+        executor: Any,
+        persistor: Any,
+        shareable_generator: Any,
+        aggregator: Any,
+        metric_comparator: Any = None,
+        model_selector: Any = None,
         learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
         learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
         learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
@@ -79,16 +79,17 @@ def __init__(
         min_responses_required: int = 1,
         wait_time_after_min_resps_received: float = 10.0,
     ):
-        check_object_type("executor", executor, Executor)
-        check_object_type("persistor", persistor, ModelPersistor)
-        check_object_type("shareable_generator", shareable_generator, ShareableGenerator)
-        check_object_type("aggregator", aggregator, Aggregator)
+        # the executor could be a wrapper object that adds the real Executor when added to the job!
+        validate_object_for_job("executor", executor, Executor)
+        validate_object_for_job("persistor", persistor, ModelPersistor)
+        validate_object_for_job("shareable_generator", shareable_generator, ShareableGenerator)
+        validate_object_for_job("aggregator", aggregator, Aggregator)
 
         if model_selector:
-            check_object_type("model_selector", model_selector, Widget)
+            validate_object_for_job("model_selector", model_selector, Widget)
 
         if metric_comparator:
-            check_object_type("metric_comparator", metric_comparator, MetricComparator)
+            validate_object_for_job("metric_comparator", metric_comparator, MetricComparator)
 
         self.executor = executor
         self.persistor = persistor
@@ -134,16 +135,16 @@ def __init__(
 class CyclicClientConfig:
     def __init__(
         self,
-        executor: Executor,
-        persistor: ModelPersistor,
-        shareable_generator: ShareableGenerator,
+        executor: Any,
+        persistor: Any,
+        shareable_generator: Any,
         learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
         learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
         final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
     ):
-        check_object_type("executor", executor, Executor)
-        check_object_type("persistor", persistor, ModelPersistor)
-        check_object_type("shareable_generator", shareable_generator, ShareableGenerator)
+        validate_object_for_job("executor", executor, Executor)
+        validate_object_for_job("persistor", persistor, ModelPersistor)
+        validate_object_for_job("shareable_generator", shareable_generator, ShareableGenerator)
 
         self.executor = executor
         self.persistor = persistor
@@ -317,3 +318,21 @@ def add_cross_site_eval(
             get_model_timeout=cse_config.get_model_timeout,
         )
         self.to_clients(client_controller, tasks=["cse_*"])
+
+
+def validate_object_for_job(name, obj, obj_type):
+    """Check whether the specified object is valid for the job.
+    The object must either have the add_to_fed_job method or be of the expected object type.
+
+    Args:
+        name: name of the object
+        obj: the object to be checked
+        obj_type: the type the object should be if it doesn't have the add_to_fed_job method
+
+    Returns: None
+
+    """
+    if has_add_to_job_method(obj):
+        return
+
+    check_object_type(name, obj, obj_type)
diff --git a/nvflare/job_config/api.py b/nvflare/job_config/api.py
index ca18e437a6..3fd0c888bd 100644
--- a/nvflare/job_config/api.py
+++ b/nvflare/job_config/api.py
@@ -29,6 +29,8 @@
 
 SPECIAL_CHARACTERS = '"!@#$%^&*()+?=,<>/'
 
+_ADD_TO_JOB_METHOD_NAME = "add_to_fed_job"
+
 
 class FedApp:
     def __init__(self):
@@ -252,7 +254,7 @@ def to(
         else:
             raise ValueError(f"this object can only be assigned to client, but tried to assign to {target}")
 
-        add_to_job_method = getattr(obj, "add_to_fed_job", None)
+        add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None)
         if add_to_job_method is not None:
             ctx = JobCtx(obj, target, id)
             result = add_to_job_method(self, ctx, **kwargs)
@@ -564,3 +566,8 @@ def check_kwargs(args_to_check: dict, args_expected: dict):
         for k in args_to_check.keys():
             if k not in args_expected:
                 raise ValueError(f"Received unexpected arg '{k}'. " f"Supported args: {args_info}")
+
+
+def has_add_to_job_method(obj: Any) -> bool:
+    add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None)
+    return add_to_job_method is not None and callable(add_to_job_method)
diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py
index b28d71e43b..6848767917 100644
--- a/nvflare/job_config/script_runner.py
+++ b/nvflare/job_config/script_runner.py
@@ -88,6 +88,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
         """
         job.check_kwargs(args_to_check=kwargs, args_expected={"tasks": False})
         tasks = kwargs.get("tasks", ["*"])
+        comp_ids = {}
 
         if self._launch_external_process:
             from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
@@ -104,11 +105,13 @@ def add_to_fed_job(self, job, ctx, **kwargs):
                 workspace_dir="{WORKSPACE}",
             )
             pipe_id = job.add_component("pipe", component, ctx)
+            comp_ids["pipe_id"] = pipe_id
 
             component = SubprocessLauncher(
                 script=self._command + " custom/" + os.path.basename(self._script) + " " + self._script_args,
             )
             launcher_id = job.add_component("launcher", component, ctx)
+            comp_ids["launcher_id"] = launcher_id
 
             executor = self._get_ex_process_executor_cls(self._framework)(
                 pipe_id=pipe_id,
@@ -126,17 +129,19 @@ def add_to_fed_job(self, job, ctx, **kwargs):
                 workspace_dir="{WORKSPACE}",
             )
             metric_pipe_id = job.add_component("metrics_pipe", component, ctx)
+            comp_ids["metric_pipe_id"] = metric_pipe_id
 
             component = MetricRelay(
                 pipe_id=metric_pipe_id,
                 event_type="fed.analytix_log_stats",
             )
             metric_relay_id = job.add_component("metric_relay", component, ctx)
+            comp_ids["metric_relay_id"] = metric_relay_id
 
             component = ExternalConfigurator(
                 component_ids=[metric_relay_id],
             )
-            job.add_component("config_preparer", component, ctx)
+            comp_ids["config_preparer_id"] = job.add_component("config_preparer", component, ctx)
         else:
             executor = self._get_in_process_executor_cls(self._framework)(
                 task_script_path=os.path.basename(self._script),
@@ -146,6 +151,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
 
         job.add_executor(executor, tasks=tasks, ctx=ctx)
         job.add_resources(resources=[self._script], ctx=ctx)
+        return comp_ids
 
     def _get_ex_process_executor_cls(self, framework: FrameworkType) -> Type[ClientAPILauncherExecutor]:
         if framework == FrameworkType.PYTORCH: