diff --git a/nvflare/job_config/base_app_config.py b/nvflare/job_config/base_app_config.py index e5c021d987..5a25d0dadc 100644 --- a/nvflare/job_config/base_app_config.py +++ b/nvflare/job_config/base_app_config.py @@ -54,7 +54,7 @@ def add_ext_script(self, ext_script: str): if not isinstance(ext_script, str): raise RuntimeError(f"ext_script must be type of str, but got {ext_script.__class__}") - if not os.path.exists(ext_script): + if not (os.path.isabs(ext_script) or os.path.exists(ext_script)): raise RuntimeError(f"Could not locate external script: {ext_script}") if not ext_script.endswith(".py"): diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py index f38ba08d40..509e35d382 100644 --- a/nvflare/job_config/fed_job_config.py +++ b/nvflare/job_config/fed_job_config.py @@ -16,6 +16,7 @@ import json import os import shutil +import sys from enum import Enum from tempfile import TemporaryDirectory from typing import Dict @@ -155,9 +156,22 @@ def _get_server_app(self, config_dir, custom_dir, fed_app): def _copy_ext_scripts(self, custom_dir, ext_scripts): for script in ext_scripts: - dest_file = os.path.join(custom_dir, script) - module = "".join(script.rsplit(".py", 1)).replace(os.sep, ".") - self._copy_source_file(custom_dir, module, script, dest_file) + if os.path.exists(script): + if os.path.isabs(script): + relative_script = self._get_relative_script(script) + else: + relative_script = script + dest_file = os.path.join(custom_dir, relative_script) + module = "".join(relative_script.rsplit(".py", 1)).replace(os.sep, ".") + self._copy_source_file(custom_dir, module, script, dest_file) + + def _get_relative_script(self, script): + package_path = "" + for path in sys.path: + if script.startswith(path): + if len(path) > len(package_path): + package_path = path + return script[len(package_path) + 1 :] def _get_class_path(self, obj, custom_dir): module = obj.__module__ @@ -294,15 +308,32 @@ def _get_filters(self, filters, custom_dir): return r def locate_imports(self, sf, dest_file): + """Locate all the import statements from the python script, including the imports across multiple lines, + using the the line break continuing. + + Args: + sf: source file + dest_file: copy to destination file + + Returns: + yield all the imports within the source file + + """ os.makedirs(os.path.dirname(dest_file), exist_ok=True) with open(dest_file, "w") as df: + trimmed = "" for line in sf: df.write(line) - trimmed = line.strip() - if trimmed.startswith("from ") and ("import " in trimmed): - yield trimmed - elif trimmed.startswith("import "): - yield trimmed + trimmed += line.strip() + if trimmed.endswith("\\"): + trimmed = trimmed[0:-1] + trimmed = trimmed.strip() + " " + else: + if trimmed.startswith("from ") and ("import " in trimmed): + yield trimmed + elif trimmed.startswith("import "): + yield trimmed + trimmed = "" def _get_deploy_map(self): deploy_map = {} diff --git a/tests/unit_test/data/job_config/sample_code.data b/tests/unit_test/data/job_config/sample_code.data new file mode 100644 index 0000000000..906937df88 --- /dev/null +++ b/tests/unit_test/data/job_config/sample_code.data @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List + +from\ + nvflare.fuel.f3.drivers.base_driver \ +import \ + BaseDriver + +from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo \ + +from nvflare.fuel.f3.drivers.driver_params import DriverCap + + +class WarpDriver(BaseDriver): + """A dummy driver to test custom driver loading""" + + def __init__(self): + super().__init__() + + @staticmethod + def supported_transports() -> List[str]: + return ["warp"] + + @staticmethod + def capabilities() -> Dict[str, Any]: + return {DriverCap.SEND_HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: False} + + def listen(self, connector: ConnectorInfo): + self.connector = connector + + def connect(self, connector: ConnectorInfo): + self.connector = connector + + def shutdown(self): + self.close_all() + + @staticmethod + def get_urls(scheme: str, resources: dict) -> (str, str): + return "warp:enterprise" diff --git a/tests/unit_test/job_config/base_app_config_test.py b/tests/unit_test/job_config/base_app_config_test.py new file mode 100644 index 0000000000..fdcad919dc --- /dev/null +++ b/tests/unit_test/job_config/base_app_config_test.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile + +import pytest + +from nvflare.job_config.base_app_config import BaseAppConfig + + +class TestBaseAppConfig: + def setup_method(self, method): + self.app_config = BaseAppConfig() + + def test_add_relative_script(self): + cwd = os.getcwd() + with tempfile.NamedTemporaryFile(dir=cwd, suffix=".py") as temp_file: + script = os.path.basename(temp_file.name) + self.app_config.add_ext_script(script) + assert script in self.app_config.ext_scripts + + def test_add_ext_script(self): + script = "/scripts/sample.py" + self.app_config.add_ext_script(script) + assert script in self.app_config.ext_scripts + + def test_add_ext_script_error(self): + script = "scripts/sample.py" + with pytest.raises(Exception): + self.app_config.add_ext_script(script) diff --git a/tests/unit_test/job_config/fed_job_config_test.py b/tests/unit_test/job_config/fed_job_config_test.py new file mode 100644 index 0000000000..8e86fd519f --- /dev/null +++ b/tests/unit_test/job_config/fed_job_config_test.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile + +from nvflare.job_config.fed_job_config import FedJobConfig + + +class TestFedJobConfig: + def test_locate_imports(self): + job_config = FedJobConfig(job_name="job_name", min_clients=1) + cwd = os.path.dirname(__file__) + source_file = os.path.join(cwd, "../data/job_config/sample_code.data") + expected = [ + "from typing import Any, Dict, List", + "from nvflare.fuel.f3.drivers.base_driver import BaseDriver", + "from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo ", + "from nvflare.fuel.f3.drivers.driver_params import DriverCap", + ] + with open(source_file, "r") as sf: + with tempfile.NamedTemporaryFile(dir=cwd, suffix=".py") as dest_file: + imports = list(job_config.locate_imports(sf, dest_file=dest_file.name)) + assert imports == expected