diff --git a/python/orca/src/bigdl/orca/ray/util/raycontext.py b/python/orca/src/bigdl/orca/ray/util/raycontext.py
index 01a1c0d14ae..b9704619a3f 100755
--- a/python/orca/src/bigdl/orca/ray/util/raycontext.py
+++ b/python/orca/src/bigdl/orca/ray/util/raycontext.py
@@ -80,7 +80,7 @@ def _prepare_env(self):
         return modified_env

     def __init__(self, python_loc, redis_port, ray_node_cpu_cores,
-                 password, object_store_memory, waitting_time_sec=6, verbose=False, env=None,
+                 password, object_store_memory, verbose=False, env=None,
                  extra_params=None):
         """object_store_memory: integer in bytes"""
         self.env = env
@@ -90,7 +90,6 @@ def __init__(self, python_loc, redis_port, ray_node_cpu_cores,
         self.ray_node_cpu_cores = ray_node_cpu_cores
         self.ray_exec = self._get_ray_exec()
         self.object_store_memory = object_store_memory
-        self.waiting_time_sec = waitting_time_sec
         self.extra_params = extra_params
         self.verbose = verbose
         # _mxnet_worker and _mxnet_server are resource tags for distributed MXNet training only
@@ -142,16 +141,12 @@ def _get_raylet_command(redis_address,
             object_store_memory=object_store_memory,
             extra_params=extra_params)

-    def _start_ray_node(self, command, tag, wait_before=5, wait_after=5):
+    def _start_ray_node(self, command, tag):
         modified_env = self._prepare_env()
         print("Starting {} by running: {}".format(tag, command))
-        print("Wait for {} sec before launching {}".format(wait_before, tag))
-        time.sleep(wait_before)
         process_info = session_execute(command=command, env=modified_env, tag=tag)
         JVMGuard.registerPids(process_info.pids)
         process_info.node_ip = rservices.get_node_ip_address()
-        print("Wait for {} sec before return process info for {}".format(wait_after, tag))
-        time.sleep(wait_after)
         return process_info

     def _get_ray_exec(self):
@@ -169,14 +164,15 @@ def _start_ray_services(iter):
             print("current address {}".format(task_addrs[tc.partitionId()]))
             print("master address {}".format(master_ip))
             redis_address = "{}:{}".format(master_ip, self.redis_port)
+            process_info = None
             if tc.partitionId() == 0:
                 print("partition id is : {}".format(tc.partitionId()))
                 process_info = self._start_ray_node(command=self._gen_master_command(),
-                                                    tag="ray-master",
-                                                    wait_after=self.waiting_time_sec)
+                                                    tag="ray-master")
                 process_info.master_addr = redis_address
-                yield process_info
-            else:
+
+            tc.barrier()
+            if tc.partitionId() != 0:
                 print("partition id is : {}".format(tc.partitionId()))
                 process_info = self._start_ray_node(
                     command=RayServiceFuncGenerator._get_raylet_command(
@@ -187,17 +183,14 @@ def _start_ray_services(iter):
                         labels=self.labels,
                         object_store_memory=self.object_store_memory,
                         extra_params=self.extra_params),
-                    tag="raylet",
-                    wait_before=self.waiting_time_sec)
-                yield process_info
-            tc.barrier()
-
+                    tag="raylet")
+            yield process_info
         return _start_ray_services


 class RayContext(object):
     def __init__(self, sc, redis_port=None, password="123456", object_store_memory=None,
-                 verbose=False, env=None, local_ray_node_num=2, waiting_time_sec=8,
+                 verbose=False, env=None, local_ray_node_num=2,
                  extra_params=None):
         """
         The RayContext would init a ray cluster on top of the configuration of SparkContext.
@@ -236,7 +229,6 @@ def __init__(self, sc, redis_port=None, password="123456", object_store_memory=N
             object_store_memory=self._enrich_object_sotre_memory(sc, object_store_memory),
             verbose=verbose,
             env=env,
-            waitting_time_sec=waiting_time_sec,
             extra_params=extra_params)
         self._gather_cluster_ips()
         from bigdl.util.common import init_executor_gateway
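
Note: the patch replaces the fixed time.sleep based startup ordering with Spark's barrier execution mode: partition 0 starts the Ray master, every partition then blocks on tc.barrier(), and the remaining partitions launch their raylets only after the barrier releases, so each partition yields exactly one process info. Below is a minimal, self-contained sketch of the same coordination pattern, assuming PySpark >= 2.4 with barrier mode available; start_master and start_worker are hypothetical placeholders for the ray-master/raylet launch commands, not code from this patch.

from pyspark import SparkContext, BarrierTaskContext

def start_master():
    # placeholder for launching the Ray head node (hypothetical)
    return "ray-master started"

def start_worker(master_ip):
    # placeholder for launching a raylet against the head node (hypothetical)
    return "raylet connected to {}".format(master_ip)

def launch(_):
    tc = BarrierTaskContext.get()
    # task 0's address plays the role of the Ray master IP
    master_ip = tc.getTaskInfos()[0].address.split(":")[0]
    process_info = None
    if tc.partitionId() == 0:
        process_info = start_master()           # only partition 0 starts the head node
    tc.barrier()                                # all tasks wait until the head node is up
    if tc.partitionId() != 0:
        process_info = start_worker(master_ip)  # workers connect after the barrier
    yield process_info                          # exactly one result per partition

if __name__ == "__main__":
    # barrier mode needs at least as many executor slots as partitions, e.g. local[4]
    sc = SparkContext(master="local[4]", appName="barrier-startup-sketch")
    print(sc.range(0, 4, numSlices=4).barrier().mapPartitions(launch).collect())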