Skip to content

Commit

Permalink
Add ray daemon to kill ray processes (intel-analytics#4571)
Browse files Browse the repository at this point in the history
* add ray daemon

* remove in bigdl

* add ray daemon in start_restricted_worker

* change to static method

* remove ProcessMonitor.register_shutdown_hook and clean_fn

* change name

* clean useless code

* add license
  • Loading branch information
shanyu-sys committed Sep 8, 2021
1 parent 753c0ac commit e74eca4
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 88 deletions.
35 changes: 1 addition & 34 deletions python/orca/src/bigdl/orca/ray/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@

import os
import subprocess
import signal
import atexit
import sys

from zoo.ray.utils import gen_shutdown_per_node, is_local
from zoo.ray.utils import is_local


class ProcessInfo(object):
Expand Down Expand Up @@ -105,7 +101,6 @@ def __init__(self, process_infos, sc, ray_rdd, raycontext, verbose=False):
self.master.append(process_info)
else:
self.slaves.append(process_info)
ProcessMonitor.register_shutdown_hook(extra_close_fn=self.clean_fn)
assert len(self.master) == 1, \
"We should got 1 master only, but we got {}".format(len(self.master))
self.master = self.master[0]
Expand All @@ -122,31 +117,3 @@ def print_ray_remote_err_out(self):
print(self.master)
for slave in self.slaves:
print(slave)

def clean_fn(self):
if not self.raycontext.initialized:
return
import ray
ray.shutdown()
if not self.sc:
print("WARNING: SparkContext has been stopped before cleaning the Ray resources")
if self.sc and (not is_local(self.sc)):
self.ray_rdd.map(gen_shutdown_per_node(self.pgids, self.node_ips)).collect()
else:
gen_shutdown_per_node(self.pgids, self.node_ips)([])

@staticmethod
def register_shutdown_hook(pgid=None, extra_close_fn=None):
def _shutdown():
if pgid:
gen_shutdown_per_node(pgid)(0)
if extra_close_fn:
extra_close_fn()

def _signal_shutdown(_signo, _stack_frame):
_shutdown()
sys.exit(0)

atexit.register(_shutdown)
signal.signal(signal.SIGTERM, _signal_shutdown)
signal.signal(signal.SIGINT, _signal_shutdown)
44 changes: 44 additions & 0 deletions python/orca/src/bigdl/orca/ray/ray_daemon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import signal
import psutil
import logging
logging.basicConfig(filename='daemon.log', level=logging.INFO)


def stop(pgid):
logging.info(f"Stopping pgid {pgid} by ray_daemon.")
try:
# SIGTERM may not kill all the children processes in the group.
os.killpg(pgid, signal.SIGKILL)
except Exception:
logging.error("Cannot kill pgid: {}".format(pgid))


def manager():
pid_to_watch = int(sys.argv[1])
pgid_to_kill = int(sys.argv[2])
import time
while psutil.pid_exists(pid_to_watch):
time.sleep(1)
stop(pgid_to_kill)


if __name__ == "__main__":
manager()
58 changes: 27 additions & 31 deletions python/orca/src/bigdl/orca/ray/raycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import os
import re
import subprocess
import time
import uuid
import random
import signal
import warnings
import tempfile
import filelock
Expand All @@ -29,26 +30,7 @@
from zoo.ray.process import session_execute, ProcessMonitor
from zoo.ray.utils import is_local
from zoo.ray.utils import resource_to_bytes


class JVMGuard:
"""
The process group id would be registered and killed in the shutdown hook of Spark Executor.
"""
@staticmethod
def register_pgid(pgid):
import traceback
try:
from zoo.common.utils import callZooFunc
import zoo
callZooFunc("float",
"jvmGuardRegisterPgid",
pgid)
except Exception as err:
print(traceback.format_exc())
print("Cannot successfully register pid into JVMGuard")
os.killpg(pgid, signal.SIGKILL)
raise err
from zoo.ray.utils import get_parent_pid


def kill_redundant_log_monitors(redis_address):
Expand Down Expand Up @@ -200,11 +182,31 @@ def _get_raylet_command(redis_address,
object_store_memory=object_store_memory,
extra_params=extra_params)

@staticmethod
def _get_spark_executor_pid():
# TODO: This might not work on OS other than Linux
this_pid = os.getpid()
pyspark_daemon_pid = get_parent_pid(this_pid)
spark_executor_pid = get_parent_pid(pyspark_daemon_pid)
return spark_executor_pid

@staticmethod
def start_ray_daemon(python_loc, pid_to_watch, pgid_to_kill):
daemon_path = os.path.join(os.path.dirname(__file__), "ray_daemon.py")
start_daemon_command = ['nohup', python_loc, daemon_path, str(pid_to_watch),
str(pgid_to_kill)]
# put ray daemon process in its children's process group to avoid being killed by spark.
subprocess.Popen(start_daemon_command, preexec_fn=os.setpgrp)
time.sleep(1)

def _start_ray_node(self, command, tag):
modified_env = self._prepare_env()
print("Starting {} by running: {}".format(tag, command))
process_info = session_execute(command=command, env=modified_env, tag=tag)
JVMGuard.register_pgid(process_info.pgid)
spark_executor_pid = RayServiceFuncGenerator._get_spark_executor_pid()
RayServiceFuncGenerator.start_ray_daemon(self.python_loc,
pid_to_watch=spark_executor_pid,
pgid_to_kill=process_info.pgid)
import ray._private.services as rservices
process_info.node_ip = rservices.get_node_ip_address()
return process_info
Expand Down Expand Up @@ -474,11 +476,6 @@ def stop(self):
return
import ray
ray.shutdown()
if not self.is_local:
if not self.ray_processesMonitor:
print("Please start the runner first before closing it")
else:
self.ray_processesMonitor.clean_fn()
self.initialized = False

def purge(self):
Expand Down Expand Up @@ -540,9 +537,6 @@ def init(self, driver_cores=0):
self._address_info = ray.init(**init_params)
else:
self.cluster_ips = self._gather_cluster_ips()
from bigdl.util.common import init_executor_gateway
init_executor_gateway(self.sc)
print("JavaGatewayServer has been successfully launched on executors")
redis_address = self._start_cluster()
self._address_info = self._start_driver(num_cores=driver_cores,
redis_address=redis_address)
Expand Down Expand Up @@ -609,7 +603,9 @@ def _start_restricted_worker(self, num_cores, node_ip_address, redis_address):
print("Executing command: {}".format(command))
process_info = session_execute(command=command, env=modified_env,
tag="raylet", fail_fast=True)
ProcessMonitor.register_shutdown_hook(pgid=process_info.pgid)
RayServiceFuncGenerator.start_ray_daemon("python",
pid_to_watch=os.getpid(),
pgid_to_kill=process_info.pgid)

def _start_driver(self, num_cores, redis_address):
print("Start to launch ray driver on local")
Expand Down
29 changes: 6 additions & 23 deletions python/orca/src/bigdl/orca/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
#

import re
import os
import signal
import psutil


def to_list(input):
Expand Down Expand Up @@ -54,27 +53,11 @@ def resource_to_bytes(resource_str):
"E.g. 50b, 100k, 250m, 30g")


def gen_shutdown_per_node(pgids, node_ips=None):
import ray._private.services as rservices
pgids = to_list(pgids)

def _shutdown_per_node(iter):
print("Stopping pgids: {}".format(pgids))
if node_ips:
current_node_ip = rservices.get_node_ip_address()
effect_pgids = [pair[0] for pair in zip(pgids, node_ips) if pair[1] == current_node_ip]
else:
effect_pgids = pgids
for pgid in effect_pgids:
print("Stopping by pgid {}".format(pgid))
try:
os.killpg(pgid, signal.SIGTERM)
except Exception:
print("WARNING: cannot kill pgid: {}".format(pgid))

return _shutdown_per_node


def is_local(sc):
master = sc.getConf().get("spark.master")
return master == "local" or master.startswith("local[")


def get_parent_pid(pid):
cur_proc = psutil.Process(pid)
return cur_proc.ppid()

0 comments on commit e74eca4

Please sign in to comment.