Skip to content

Commit

Permalink
Support ScriptRunner in ccwf_job (#2825)
Browse files Browse the repository at this point in the history
* support ScriptRunner in ccwf_job

* remove unused import

* added object type check
  • Loading branch information
yanchengnv authored Aug 22, 2024
1 parent f766f90 commit d7c92cf
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 22 deletions.
59 changes: 39 additions & 20 deletions nvflare/app_common/ccwf/ccwf_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from typing import Any, List, Optional

from nvflare.apis.executor import Executor
from nvflare.app_common.abstract.aggregator import Aggregator
Expand All @@ -21,7 +21,7 @@
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.common import Constant, CyclicOrder
from nvflare.fuel.utils.validation_utils import check_object_type
from nvflare.job_config.api import FedJob
from nvflare.job_config.api import FedJob, has_add_to_job_method
from nvflare.widgets.widget import Widget

from .cse_client_ctl import CrossSiteEvalClientController
Expand Down Expand Up @@ -65,12 +65,12 @@ def __init__(
class SwarmClientConfig:
def __init__(
self,
executor: Executor,
persistor: ModelPersistor,
shareable_generator: ShareableGenerator,
aggregator: Aggregator,
metric_comparator: MetricComparator = None,
model_selector: Widget = None,
executor: Any,
persistor: Any,
shareable_generator: Any,
aggregator: Any,
metric_comparator: Any = None,
model_selector: Any = None,
learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
Expand All @@ -79,16 +79,17 @@ def __init__(
min_responses_required: int = 1,
wait_time_after_min_resps_received: float = 10.0,
):
check_object_type("executor", executor, Executor)
check_object_type("persistor", persistor, ModelPersistor)
check_object_type("shareable_generator", shareable_generator, ShareableGenerator)
check_object_type("aggregator", aggregator, Aggregator)
# the executor could be a wrapper object that adds real Executor when added to job!
validate_object_for_job("executor", executor, Executor)
validate_object_for_job("persistor", persistor, ModelPersistor)
validate_object_for_job("shareable_generator", shareable_generator, ShareableGenerator)
validate_object_for_job("aggregator", aggregator, Aggregator)

if model_selector:
check_object_type("model_selector", model_selector, Widget)
validate_object_for_job("model_selector", model_selector, Widget)

if metric_comparator:
check_object_type("metric_comparator", metric_comparator, MetricComparator)
validate_object_for_job("metric_comparator", metric_comparator, MetricComparator)

self.executor = executor
self.persistor = persistor
Expand Down Expand Up @@ -134,16 +135,16 @@ def __init__(
class CyclicClientConfig:
def __init__(
self,
executor: Executor,
persistor: ModelPersistor,
shareable_generator: ShareableGenerator,
executor: Any,
persistor: Any,
shareable_generator: Any,
learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
):
check_object_type("executor", executor, Executor)
check_object_type("persistor", persistor, ModelPersistor)
check_object_type("shareable_generator", shareable_generator, ShareableGenerator)
validate_object_for_job("executor", executor, Executor)
validate_object_for_job("persistor", persistor, ModelPersistor)
validate_object_for_job("shareable_generator", shareable_generator, ShareableGenerator)

self.executor = executor
self.persistor = persistor
Expand Down Expand Up @@ -317,3 +318,21 @@ def add_cross_site_eval(
get_model_timeout=cse_config.get_model_timeout,
)
self.to_clients(client_controller, tasks=["cse_*"])


def validate_object_for_job(name, obj, obj_type):
    """Validate that an object is acceptable as a job component.

    An object is accepted without further checks when ``has_add_to_job_method``
    reports that it provides the add-to-job hook (e.g. a wrapper that adds the
    real component when it is added to the job). Otherwise validation is
    delegated to ``check_object_type``.

    Args:
        name: name of the object, passed through to ``check_object_type``
        obj: the object to be checked
        obj_type: the type the object is expected to be when it does not
            provide the add-to-job hook

    Returns: None
    """
    # Only objects lacking the add-to-job hook need a type check.
    if not has_add_to_job_method(obj):
        check_object_type(name, obj, obj_type)
9 changes: 8 additions & 1 deletion nvflare/job_config/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

SPECIAL_CHARACTERS = '"!@#$%^&*()+?=,<>/'

_ADD_TO_JOB_METHOD_NAME = "add_to_fed_job"


class FedApp:
def __init__(self):
Expand Down Expand Up @@ -252,7 +254,7 @@ def to(
else:
raise ValueError(f"this object can only be assigned to client, but tried to assign to {target}")

add_to_job_method = getattr(obj, "add_to_fed_job", None)
add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None)
if add_to_job_method is not None:
ctx = JobCtx(obj, target, id)
result = add_to_job_method(self, ctx, **kwargs)
Expand Down Expand Up @@ -564,3 +566,8 @@ def check_kwargs(args_to_check: dict, args_expected: dict):
for k in args_to_check.keys():
if k not in args_expected:
raise ValueError(f"Received unexpected arg '{k}'. " f"Supported args: {args_info}")


def has_add_to_job_method(obj: Any) -> bool:
    """Tell whether the object exposes a callable add-to-job hook.

    Args:
        obj: the object to inspect

    Returns: True if the attribute named by ``_ADD_TO_JOB_METHOD_NAME`` exists
        on the object and is callable; False otherwise.
    """
    # A missing attribute falls back to None, and callable(None) is False,
    # so no separate "is not None" test is required.
    return callable(getattr(obj, _ADD_TO_JOB_METHOD_NAME, None))
8 changes: 7 additions & 1 deletion nvflare/job_config/script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
"""
job.check_kwargs(args_to_check=kwargs, args_expected={"tasks": False})
tasks = kwargs.get("tasks", ["*"])
comp_ids = {}

if self._launch_external_process:
from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
Expand All @@ -104,11 +105,13 @@ def add_to_fed_job(self, job, ctx, **kwargs):
workspace_dir="{WORKSPACE}",
)
pipe_id = job.add_component("pipe", component, ctx)
comp_ids["pipe_id"] = pipe_id

component = SubprocessLauncher(
script=self._command + " custom/" + os.path.basename(self._script) + " " + self._script_args,
)
launcher_id = job.add_component("launcher", component, ctx)
comp_ids["launcher_id"] = launcher_id

executor = self._get_ex_process_executor_cls(self._framework)(
pipe_id=pipe_id,
Expand All @@ -126,17 +129,19 @@ def add_to_fed_job(self, job, ctx, **kwargs):
workspace_dir="{WORKSPACE}",
)
metric_pipe_id = job.add_component("metrics_pipe", component, ctx)
comp_ids["metric_pipe_id"] = metric_pipe_id

component = MetricRelay(
pipe_id=metric_pipe_id,
event_type="fed.analytix_log_stats",
)
metric_relay_id = job.add_component("metric_relay", component, ctx)
comp_ids["metric_relay_id"] = metric_relay_id

component = ExternalConfigurator(
component_ids=[metric_relay_id],
)
job.add_component("config_preparer", component, ctx)
comp_ids["config_preparer_id"] = job.add_component("config_preparer", component, ctx)
else:
executor = self._get_in_process_executor_cls(self._framework)(
task_script_path=os.path.basename(self._script),
Expand All @@ -146,6 +151,7 @@ def add_to_fed_job(self, job, ctx, **kwargs):
job.add_executor(executor, tasks=tasks, ctx=ctx)

job.add_resources(resources=[self._script], ctx=ctx)
return comp_ids

def _get_ex_process_executor_cls(self, framework: FrameworkType) -> Type[ClientAPILauncherExecutor]:
if framework == FrameworkType.PYTORCH:
Expand Down

0 comments on commit d7c92cf

Please sign in to comment.