Skip to content

Commit

Permalink
BugFix: InProcessClientAPIExecutor's TaskScriptRunner (#2558)
Browse files Browse the repository at this point in the history
* 1) find the script's full path, taking the site name into account, to avoid loading and running the wrong site's script
2) make sure that a failed task script causes the client to return a failure status, which triggers the job to stop rather than wait forever
3) add various unit tests

* sort key in unit test

* add logic to improve error message

* style format

* add more tests and logic

* code format

* code format

* fix steps error

* fix global steps

* rollback some changes and split it into another PR

* rollback some changes and split it into another PR

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
chesterxgchen and YuanTingHsieh authored May 11, 2024
1 parent 27726f1 commit d050cf2
Show file tree
Hide file tree
Showing 13 changed files with 537 additions and 34 deletions.
15 changes: 11 additions & 4 deletions nvflare/app_common/executors/in_process_client_api_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
submit_model_task_name: str = "submit_model",
):
super(InProcessClientAPIExecutor, self).__init__()
self._abort = False
self._client_api = None
self._result_pull_interval = result_pull_interval
self._log_pull_interval = log_pull_interval
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
self._event_manager = EventManager(self._data_bus)
self._data_bus.subscribe([TOPIC_LOCAL_RESULT], self.local_result_callback)
self._data_bus.subscribe([TOPIC_LOG_DATA], self.log_result_callback)
self._data_bus.subscribe([TOPIC_ABORT, TOPIC_STOP], self.to_abort_callback)
self.local_result = None
self._fl_ctx = None
self._task_fn_path = None
Expand All @@ -106,17 +108,19 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
self._init_converter(fl_ctx)

self._task_fn_wrapper = TaskScriptRunner(
script_path=self._task_script_path, script_args=self._task_script_args
site_name=fl_ctx.get_identity_name(),
script_path=self._task_script_path,
script_args=self._task_script_args,
)

self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run)
self._task_fn_thread.start()

meta = self._prepare_task_meta(fl_ctx, None)
self._client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=self._result_pull_interval)
self._client_api.init()
self._data_bus.put_data(CLIENT_API_KEY, self._client_api)

self._task_fn_thread.start()

elif event_type == EventType.END_RUN:
self._event_manager.fire_event(TOPIC_STOP, "END_RUN received")
if self._task_fn_thread:
Expand All @@ -142,7 +146,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
# wait for result
self.log_info(fl_ctx, "Waiting for result from peer")
while True:
if abort_signal.triggered:
if abort_signal.triggered or self._abort is True:
# notify peer that the task is aborted
self._event_manager.fire_event(TOPIC_ABORT, f"{task_name}' is aborted, abort_signal_triggered")
return make_reply(ReturnCode.TASK_ABORTED)
Expand Down Expand Up @@ -231,3 +235,6 @@ def log_result_callback(self, topic, data, databus):
# fire_fed_event = True w/o fed_event_converter somehow did not work
with self._engine.new_context() as fl_ctx:
send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=ANALYTIC_EVENT_TYPE, fire_fed_event=False)

# Data-bus callback subscribed to TOPIC_ABORT and TOPIC_STOP in __init__.
# It sets the _abort flag, which the wait loop in execute() checks alongside abort_signal.triggered.
# The topic, data and databus arguments are not used.
def to_abort_callback(self, topic, data, databus):
self._abort = True
60 changes: 43 additions & 17 deletions nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,75 +14,101 @@
import builtins
import logging
import os
import runpy
import sys
import traceback

from nvflare.client.in_process.api import TOPIC_ABORT
from nvflare.fuel.data_event.data_bus import DataBus
from nvflare.fuel.data_event.event_manager import EventManager

print_fn = builtins.print


class TaskScriptRunner:
logger = logging.getLogger(__name__)

def __init__(self, script_path: str, script_args: str = None):
def __init__(self, site_name: str, script_path: str, script_args: str = None, redirect_print_to_log=True):
"""Wrapper for function given function path and args
Args:
site_name (str): site name
script_path (str): script file name, such as train.py
script_args (str, Optional): script arguments to pass in.
"""

self.redirect_print_to_log = redirect_print_to_log
self.event_manager = EventManager(DataBus())
self.script_args = script_args
self.client_api = None
self.site_name = site_name
self.logger = logging.getLogger(self.__class__.__name__)
self.script_path = self.get_script_full_path(script_path)
self.script_path = script_path
self.script_full_path = self.get_script_full_path(self.site_name, self.script_path)

def run(self):
"""Call the task_fn with any required arguments."""
self.logger.info(f"\n start task run() with {self.script_path}")
self.logger.info(f"\n start task run() with full path: {self.script_full_path}")
try:
import runpy

curr_argv = sys.argv
builtins.print = log_print
builtins.print = log_print if self.redirect_print_to_log else print_fn
sys.argv = self.get_sys_argv()
runpy.run_path(self.script_path, run_name="__main__")
runpy.run_path(self.script_full_path, run_name="__main__")
sys.argv = curr_argv

except ImportError as ie:
msg = "attempted relative import with no known parent package"
if ie.msg == msg:
xs = [p for p in sys.path if self.script_full_path.startswith(p)]
import_base_path = max(xs, key=len)
raise ImportError(
f"{ie.msg}, the relative import is not support. python import is based off the sys.path: {import_base_path}"
)
else:
raise ie
except Exception as e:
msg = traceback.format_exc()
self.logger.error(msg)
if self.client_api:
self.client_api.exec_queue.ask_abort(msg)
self.logger.error("fire abort event")
self.event_manager.fire_event(TOPIC_ABORT, f"'{self.script_full_path}' is aborted, {msg}")
raise e
finally:
builtins.print = print_fn

def get_sys_argv(self):
args_list = [] if not self.script_args else self.script_args.split()
return [self.script_path] + args_list
return [self.script_full_path] + args_list

def get_script_full_path(self, script_path) -> str:
def get_script_full_path(self, site_name, script_path) -> str:
target_file = None
script_filename = os.path.basename(script_path)
script_dirs = os.path.dirname(script_path)

if os.path.isabs(script_path):
if not os.path.isfile(script_path):
raise ValueError(f"script_path='{script_path}' not found")
return script_path

for r, dirs, files in os.walk(os.getcwd()):
for f in files:
absolute_path = os.path.join(r, f)
if absolute_path.endswith(script_path):
parent_dir = absolute_path[: absolute_path.find(script_path)].rstrip(os.sep)
if os.path.isdir(parent_dir):
target_file = absolute_path
break
path_components = parent_dir.split(os.path.sep)
if site_name in path_components:
target_file = absolute_path
break

if not script_dirs and f == script_filename:
if not site_name and not script_dirs and f == script_filename:
target_file = absolute_path
break

if target_file:
break

if not target_file:
raise ValueError(f"{script_path} is not found")
msg = f"Can not find {script_path}"
self.event_manager.fire_event(TOPIC_ABORT, f"'{self.script_path}' is aborted, {msg}")
raise ValueError(msg)
return target_file


Expand Down
174 changes: 162 additions & 12 deletions tests/unit_test/app_common/executors/task_script_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,54 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest

import pytest

from nvflare.app_common.executors.task_script_runner import TaskScriptRunner
from nvflare.client.in_process.api import TOPIC_ABORT, TOPIC_STOP


class TestTaskScriptRunner(unittest.TestCase):
def test_app_scripts_and_args(self):
curr_dir = os.getcwd()
script_path = "nvflare/cli.py"
script_args = "--batch_size 4"
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args)
wrapper = TaskScriptRunner(site_name="", script_path=script_path, script_args=script_args)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertTrue(wrapper.script_full_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"])

def test_app_scripts_and_args2(self):
curr_dir = os.getcwd()
script_path = "cli.py"
script_args = "--batch_size 4"
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args)
wrapper = TaskScriptRunner(site_name="", script_path=script_path, script_args=script_args)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertTrue(wrapper.script_full_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"])

def test_app_scripts_with_sub_dirs1(self):
curr_dir = os.getcwd()
script_path = "nvflare/__init__.py"
wrapper = TaskScriptRunner(script_path=script_path)
wrapper = TaskScriptRunner(site_name="", script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertTrue(wrapper.script_full_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "__init__.py")])

def test_app_scripts_with_sub_dirs2(self):
curr_dir = os.getcwd()
script_path = "nvflare/app_common/executors/__init__.py"
wrapper = TaskScriptRunner(script_path=script_path)
wrapper = TaskScriptRunner(site_name="", script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertTrue(wrapper.script_full_path.endswith(script_path))
self.assertEqual(
wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "app_common", "executors", "__init__.py")]
)

def test_app_scripts_with_sub_dirs3(self):
curr_dir = os.getcwd()
script_path = "executors/task_script_runner.py"
wrapper = TaskScriptRunner(script_path=script_path)
wrapper = TaskScriptRunner(site_name="app_common", script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertTrue(wrapper.script_full_path.endswith(script_path))
self.assertEqual(
wrapper.get_sys_argv(),
[os.path.join(curr_dir, "nvflare", "app_common", "executors", "task_script_runner.py")],
Expand All @@ -68,7 +72,153 @@ def test_app_scripts_with_sub_dirs3(self):
def test_app_scripts_with_sub_dirs4(self):
curr_dir = os.getcwd()
script_path = "in_process/api.py"
wrapper = TaskScriptRunner(script_path=script_path)
wrapper = TaskScriptRunner(site_name="client", script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertTrue(wrapper.script_full_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "client", "in_process", "api.py")])

def test_file_not_found_with_exception(self):
    """Constructing a runner raises when the script is not found for the given site.

    ``in_process/api.py`` is looked up with ``site_name="site-1"``; the test
    expects the constructor to raise ``ValueError("Can not find ...")``
    rather than return a match that belongs to another directory.
    """
    script_path = "in_process/api.py"
    # Only the constructor goes inside the raises-block: it is the call that
    # raises, so any statement placed after it would never execute.
    with pytest.raises(ValueError, match="Can not find in_process/api.py"):
        TaskScriptRunner(site_name="site-1", script_path=script_path)

def test_run_scripts_with_sub_dirs(self):
    """Resolve ``train.py`` for ``site-1`` and run it with extra arguments.

    The three ``custom`` folders of the ``in_proc_job`` test job are put on
    ``sys.path`` for the duration of the test; the runner is expected to
    pick the copy under ``site-1/custom``.
    """
    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entries into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        for rel_dir in (
            "tests/unit_test/data/jobs/in_proc_job/custom",
            "tests/unit_test/data/jobs/in_proc_job/server/custom",
            "tests/unit_test/data/jobs/in_proc_job/site-1/custom",
        ):
            sys.path.append(os.path.join(os.getcwd(), rel_dir))

        script_path = "train.py"
        wrapper = TaskScriptRunner(
            site_name="site-1", script_path=script_path, script_args=script_args, redirect_print_to_log=False
        )
        self.assertTrue(wrapper.script_full_path.endswith(script_path))
        expected_path = os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/site-1/custom/train.py")
        self.assertEqual(wrapper.get_sys_argv(), [expected_path, "--batch_size", "4"])
        wrapper.run()
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path

def test_run_scripts_with_sub_dirs2(self):
    """Resolve ``train.py`` for the ``server`` site and run it with extra arguments.

    Same set-up as the ``site-1`` variant, but the runner is expected to
    pick the copy under ``server/custom``.
    """
    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entries into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        for rel_dir in (
            "tests/unit_test/data/jobs/in_proc_job/custom",
            "tests/unit_test/data/jobs/in_proc_job/server/custom",
            "tests/unit_test/data/jobs/in_proc_job/site-1/custom",
        ):
            sys.path.append(os.path.join(os.getcwd(), rel_dir))

        script_path = "train.py"
        wrapper = TaskScriptRunner(
            site_name="server", script_path=script_path, script_args=script_args, redirect_print_to_log=False
        )
        self.assertTrue(wrapper.script_full_path.endswith(script_path))
        expected_path = os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/server/custom/train.py")
        self.assertEqual(wrapper.get_sys_argv(), [expected_path, "--batch_size", "4"])
        wrapper.run()
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path

def test_run_scripts_with_sub_dirs3(self):
    """Resolve ``src/train.py`` with an empty site name and run it.

    With ``site_name=""`` and a script path that contains a sub-directory,
    the runner is expected to pick the job-level copy under
    ``custom/src``.
    """
    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entries into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        for rel_dir in (
            "tests/unit_test/data/jobs/in_proc_job/custom",
            "tests/unit_test/data/jobs/in_proc_job/server/custom",
            "tests/unit_test/data/jobs/in_proc_job/site-1/custom",
        ):
            sys.path.append(os.path.join(os.getcwd(), rel_dir))

        script_path = "src/train.py"
        wrapper = TaskScriptRunner(
            site_name="", script_path=script_path, script_args=script_args, redirect_print_to_log=False
        )
        self.assertTrue(wrapper.script_full_path.endswith(script_path))
        expected_path = os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/custom/src/train.py")
        self.assertEqual(wrapper.get_sys_argv(), [expected_path, "--batch_size", "4"])
        wrapper.run()
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path

def test_run_failed_scripts(self):
    """A script that raises must propagate its error out of ``run()``.

    ``failed_train.py`` is expected to raise
    ``ValueError("failed to train model")``. ``abort_callback`` is
    subscribed to the abort/stop topics on the runner's data bus so that the
    abort notification can be observed in the test output.
    """
    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entries into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        for rel_dir in (
            "tests/unit_test/data/jobs/in_proc_job/custom",
            "tests/unit_test/data/jobs/in_proc_job/server/custom",
            "tests/unit_test/data/jobs/in_proc_job/site-1/custom",
        ):
            sys.path.append(os.path.join(os.getcwd(), rel_dir))

        script_path = "failed_train.py"
        wrapper = TaskScriptRunner(
            site_name="site-1", script_path=script_path, script_args=script_args, redirect_print_to_log=False
        )
        wrapper.event_manager.data_bus.subscribe([TOPIC_ABORT, TOPIC_STOP], self.abort_callback)

        self.assertTrue(wrapper.script_full_path.endswith(script_path))
        with pytest.raises(ValueError, match="failed to train model"):
            # 1) check that the exception is propagated to the caller;
            # 2) more importantly, see whether the abort callback is triggered
            #    (visible through the messages it prints).
            wrapper.run()
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path

# Data-bus subscriber that test_run_failed_scripts registers for TOPIC_ABORT and TOPIC_STOP.
# It prints begin/end markers and checks that the delivered topic is TOPIC_ABORT.
def abort_callback(self, topic, data, databus):
print("\n ===== calling abort_callback begin")
# assert failure here will not cause test to fail
# NOTE(review): presumably the data bus does not propagate callback exceptions to the test -- confirm
self.assertEqual(topic, TOPIC_ABORT)
print("\n ===== calling abort_callback end")

def test_run_relative_import_scripts(self):
    """A script using a parent-less relative import fails with the extended ImportError message.

    The expected message ends with the ``sys.path`` entry the script lives
    under, here the ``site-1/custom`` folder of the test job.
    """
    import re  # local import: only needed to escape the expected message

    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entries into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        for rel_dir in (
            "tests/unit_test/data/jobs/in_proc_job/custom",
            "tests/unit_test/data/jobs/in_proc_job/server/custom",
            "tests/unit_test/data/jobs/in_proc_job/site-1/custom",
        ):
            sys.path.append(os.path.join(os.getcwd(), rel_dir))

        script_path = "relative_import_train.py"
        wrapper = TaskScriptRunner(
            site_name="site-1", script_path=script_path, script_args=script_args, redirect_print_to_log=False
        )
        self.assertTrue(wrapper.script_full_path.endswith(script_path))
        path = os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/site-1/custom")
        msg = f"attempted relative import with no known parent package, the relative import is not support. python import is based off the sys.path: {path}"
        # ``match`` is applied as a regular expression; the message embeds a
        # filesystem path, so escape it to compare the text literally.
        with pytest.raises(ImportError, match=re.escape(msg)):
            # check the ImportError
            wrapper.run()
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path

def test_run_abs_path_scripts(self):
    """An absolute script path that does not exist is rejected by the constructor."""
    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entry into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        sys.path.append(os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/site-1/custom"))

        # path doesn't exist
        script_path = "/foo/dummy/train.py"
        with pytest.raises(ValueError, match="script_path='/foo/dummy/train.py' not found"):
            # The constructor is the call that raises, so its result is not bound.
            TaskScriptRunner(
                site_name="site-1", script_path=script_path, script_args=script_args, redirect_print_to_log=False
            )
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path

def test_run_abs_path_scripts2(self):
    """An existing absolute script path is accepted and the script runs without raising."""
    # Snapshot a *copy*: sys.path is a mutable list, so holding only a
    # reference would make the restore in ``finally`` a no-op and leak the
    # appended entry into every later test.
    saved_sys_path = list(sys.path)
    script_args = "--batch_size 4"

    try:
        sys.path.append(os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/site-1/custom"))

        script_path = os.path.join(os.getcwd(), "tests/unit_test/data/jobs/in_proc_job/site-1/custom/train.py")
        wrapper = TaskScriptRunner(
            site_name="site-1", script_path=script_path, script_args=script_args, redirect_print_to_log=False
        )
        wrapper.run()
    finally:
        # Slice-assign so the original list object is restored in place.
        sys.path[:] = saved_sys_path
Loading

0 comments on commit d050cf2

Please sign in to comment.