# Patch summary (review notes; everything before the first "diff --git" header is commentary,
# not part of any hunk).
#
# Replaces in-process shutdown hooks with an external watchdog script for killing Ray
# process groups:
#   * ray/process.py: removes ProcessMonitor.clean_fn and ProcessMonitor.register_shutdown_hook
#     (atexit + SIGTERM/SIGINT handlers), the call that registered them in __init__, and the
#     signal/atexit/sys imports.
#   * ray/ray_daemon.py (new file): manager() reads a pid and a pgid from sys.argv, polls
#     psutil.pid_exists(pid) once per second, and once the pid is gone calls stop(), which sends
#     SIGKILL to the process group. Logging is configured at import time to 'daemon.log'.
#   * ray/raycontext.py: removes JVMGuard and the init_executor_gateway call; adds
#     RayServiceFuncGenerator._get_spark_executor_pid (parent of the parent of the current
#     process) and RayServiceFuncGenerator.start_ray_daemon (launches ray_daemon.py through
#     nohup with preexec_fn=os.setpgrp, then sleeps one second). _start_ray_node and
#     _start_restricted_worker now start the daemon; stop() no longer calls clean_fn.
#   * ray/utils.py: removes gen_shutdown_per_node; adds get_parent_pid, a thin wrapper over
#     psutil.Process(pid).ppid().
#
# NOTE(review): line breaks and indentation in this patch were restored from a whitespace-
#   collapsed copy. Hunk line counts were checked against every @@ header, but the leading
#   whitespace of context lines (continuation lines in particular) could not be confirmed;
#   apply with whitespace tolerance or re-export from the repository if it does not apply.
# NOTE(review): os.setpgrp as preexec_fn makes the daemon the leader of its own new process
#   group; the in-code comment ("in its children's process group") presumably means that.
# NOTE(review): _start_restricted_worker passes the literal "python" as the interpreter while
#   _start_ray_node passes self.python_loc; confirm "python" resolves to the intended
#   interpreter on the driver.
# NOTE(review): 'daemon.log' is a relative path, so the log file lands in the daemon's working
#   directory; stop() logs a failed killpg without the exception details.
# NOTE(review): file paths are under bigdl/orca while the imports reference zoo.ray.*;
#   presumably handled by a package rename elsewhere; confirm.
#
diff --git a/python/orca/src/bigdl/orca/ray/process.py b/python/orca/src/bigdl/orca/ray/process.py
index ccd8d0b3fd3..0da6cf54574 100644
--- a/python/orca/src/bigdl/orca/ray/process.py
+++ b/python/orca/src/bigdl/orca/ray/process.py
@@ -16,11 +16,7 @@
 
 import os
 import subprocess
-import signal
-import atexit
-import sys
-
-from zoo.ray.utils import gen_shutdown_per_node, is_local
+from zoo.ray.utils import is_local
 
 
 class ProcessInfo(object):
@@ -105,7 +101,6 @@ def __init__(self, process_infos, sc, ray_rdd, raycontext, verbose=False):
                 self.master.append(process_info)
             else:
                 self.slaves.append(process_info)
-        ProcessMonitor.register_shutdown_hook(extra_close_fn=self.clean_fn)
         assert len(self.master) == 1, \
             "We should got 1 master only, but we got {}".format(len(self.master))
         self.master = self.master[0]
@@ -122,31 +117,3 @@ def print_ray_remote_err_out(self):
             print(self.master)
             for slave in self.slaves:
                 print(slave)
-
-    def clean_fn(self):
-        if not self.raycontext.initialized:
-            return
-        import ray
-        ray.shutdown()
-        if not self.sc:
-            print("WARNING: SparkContext has been stopped before cleaning the Ray resources")
-        if self.sc and (not is_local(self.sc)):
-            self.ray_rdd.map(gen_shutdown_per_node(self.pgids, self.node_ips)).collect()
-        else:
-            gen_shutdown_per_node(self.pgids, self.node_ips)([])
-
-    @staticmethod
-    def register_shutdown_hook(pgid=None, extra_close_fn=None):
-        def _shutdown():
-            if pgid:
-                gen_shutdown_per_node(pgid)(0)
-            if extra_close_fn:
-                extra_close_fn()
-
-        def _signal_shutdown(_signo, _stack_frame):
-            _shutdown()
-            sys.exit(0)
-
-        atexit.register(_shutdown)
-        signal.signal(signal.SIGTERM, _signal_shutdown)
-        signal.signal(signal.SIGINT, _signal_shutdown)
diff --git a/python/orca/src/bigdl/orca/ray/ray_daemon.py b/python/orca/src/bigdl/orca/ray/ray_daemon.py
new file mode 100644
index 00000000000..fe5ff7533ea
--- /dev/null
+++ b/python/orca/src/bigdl/orca/ray/ray_daemon.py
@@ -0,0 +1,44 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import signal
+import psutil
+import logging
+logging.basicConfig(filename='daemon.log', level=logging.INFO)
+
+
+def stop(pgid):
+    logging.info(f"Stopping pgid {pgid} by ray_daemon.")
+    try:
+        # SIGTERM may not kill all the children processes in the group.
+        os.killpg(pgid, signal.SIGKILL)
+    except Exception:
+        logging.error("Cannot kill pgid: {}".format(pgid))
+
+
+def manager():
+    pid_to_watch = int(sys.argv[1])
+    pgid_to_kill = int(sys.argv[2])
+    import time
+    while psutil.pid_exists(pid_to_watch):
+        time.sleep(1)
+    stop(pgid_to_kill)
+
+
+if __name__ == "__main__":
+    manager()
diff --git a/python/orca/src/bigdl/orca/ray/raycontext.py b/python/orca/src/bigdl/orca/ray/raycontext.py
index b6f339bafba..28ff4660a6f 100755
--- a/python/orca/src/bigdl/orca/ray/raycontext.py
+++ b/python/orca/src/bigdl/orca/ray/raycontext.py
@@ -17,9 +17,10 @@
 
 import os
 import re
+import subprocess
+import time
 import uuid
 import random
-import signal
 import warnings
 import tempfile
 import filelock
@@ -29,26 +30,7 @@
 from zoo.ray.process import session_execute, ProcessMonitor
 from zoo.ray.utils import is_local
 from zoo.ray.utils import resource_to_bytes
-
-
-class JVMGuard:
-    """
-    The process group id would be registered and killed in the shutdown hook of Spark Executor.
-    """
-    @staticmethod
-    def register_pgid(pgid):
-        import traceback
-        try:
-            from zoo.common.utils import callZooFunc
-            import zoo
-            callZooFunc("float",
-                        "jvmGuardRegisterPgid",
-                        pgid)
-        except Exception as err:
-            print(traceback.format_exc())
-            print("Cannot successfully register pid into JVMGuard")
-            os.killpg(pgid, signal.SIGKILL)
-            raise err
+from zoo.ray.utils import get_parent_pid
 
 
 def kill_redundant_log_monitors(redis_address):
@@ -200,11 +182,31 @@ def _get_raylet_command(redis_address,
                                                        object_store_memory=object_store_memory,
                                                        extra_params=extra_params)
 
+    @staticmethod
+    def _get_spark_executor_pid():
+        # TODO: This might not work on OS other than Linux
+        this_pid = os.getpid()
+        pyspark_daemon_pid = get_parent_pid(this_pid)
+        spark_executor_pid = get_parent_pid(pyspark_daemon_pid)
+        return spark_executor_pid
+
+    @staticmethod
+    def start_ray_daemon(python_loc, pid_to_watch, pgid_to_kill):
+        daemon_path = os.path.join(os.path.dirname(__file__), "ray_daemon.py")
+        start_daemon_command = ['nohup', python_loc, daemon_path, str(pid_to_watch),
+                                str(pgid_to_kill)]
+        # put ray daemon process in its children's process group to avoid being killed by spark.
+        subprocess.Popen(start_daemon_command, preexec_fn=os.setpgrp)
+        time.sleep(1)
+
     def _start_ray_node(self, command, tag):
         modified_env = self._prepare_env()
         print("Starting {} by running: {}".format(tag, command))
         process_info = session_execute(command=command, env=modified_env, tag=tag)
-        JVMGuard.register_pgid(process_info.pgid)
+        spark_executor_pid = RayServiceFuncGenerator._get_spark_executor_pid()
+        RayServiceFuncGenerator.start_ray_daemon(self.python_loc,
+                                                 pid_to_watch=spark_executor_pid,
+                                                 pgid_to_kill=process_info.pgid)
         import ray._private.services as rservices
         process_info.node_ip = rservices.get_node_ip_address()
         return process_info
@@ -474,11 +476,6 @@ def stop(self):
             return
         import ray
         ray.shutdown()
-        if not self.is_local:
-            if not self.ray_processesMonitor:
-                print("Please start the runner first before closing it")
-            else:
-                self.ray_processesMonitor.clean_fn()
         self.initialized = False
 
     def purge(self):
@@ -540,9 +537,6 @@ def init(self, driver_cores=0):
                 self._address_info = ray.init(**init_params)
             else:
                 self.cluster_ips = self._gather_cluster_ips()
-                from bigdl.util.common import init_executor_gateway
-                init_executor_gateway(self.sc)
-                print("JavaGatewayServer has been successfully launched on executors")
                 redis_address = self._start_cluster()
                 self._address_info = self._start_driver(num_cores=driver_cores,
                                                         redis_address=redis_address)
@@ -609,7 +603,9 @@ def _start_restricted_worker(self, num_cores, node_ip_address, redis_address):
         print("Executing command: {}".format(command))
         process_info = session_execute(command=command, env=modified_env,
                                        tag="raylet", fail_fast=True)
-        ProcessMonitor.register_shutdown_hook(pgid=process_info.pgid)
+        RayServiceFuncGenerator.start_ray_daemon("python",
+                                                 pid_to_watch=os.getpid(),
+                                                 pgid_to_kill=process_info.pgid)
 
     def _start_driver(self, num_cores, redis_address):
         print("Start to launch ray driver on local")
diff --git a/python/orca/src/bigdl/orca/ray/utils.py b/python/orca/src/bigdl/orca/ray/utils.py
index cdaa06ce3ee..de09cb6f3db 100644
--- a/python/orca/src/bigdl/orca/ray/utils.py
+++ b/python/orca/src/bigdl/orca/ray/utils.py
@@ -15,8 +15,7 @@
 #
 
 import re
-import os
-import signal
+import psutil
 
 
 def to_list(input):
@@ -54,27 +53,11 @@ def resource_to_bytes(resource_str):
                     "E.g. 50b, 100k, 250m, 30g")
 
 
-def gen_shutdown_per_node(pgids, node_ips=None):
-    import ray._private.services as rservices
-    pgids = to_list(pgids)
-
-    def _shutdown_per_node(iter):
-        print("Stopping pgids: {}".format(pgids))
-        if node_ips:
-            current_node_ip = rservices.get_node_ip_address()
-            effect_pgids = [pair[0] for pair in zip(pgids, node_ips) if pair[1] == current_node_ip]
-        else:
-            effect_pgids = pgids
-        for pgid in effect_pgids:
-            print("Stopping by pgid {}".format(pgid))
-            try:
-                os.killpg(pgid, signal.SIGTERM)
-            except Exception:
-                print("WARNING: cannot kill pgid: {}".format(pgid))
-
-    return _shutdown_per_node
-
-
 def is_local(sc):
     master = sc.getConf().get("spark.master")
     return master == "local" or master.startswith("local[")
+
+
+def get_parent_pid(pid):
+    cur_proc = psutil.Process(pid)
+    return cur_proc.ppid()