Skip to content

Commit

Permalink
Fix orca ray pytorch example (intel-analytics#3007)
Browse files Browse the repository at this point in the history
* fix horovod pytorch example

* fix bug

* fix process group

* fix style

* fix tests

* fix test

* fix tests

* revert ray context change
  • Loading branch information
yangw1234 committed Oct 28, 2020
1 parent a20a495 commit d55a80c
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/orca/src/bigdl/orca/ray/raycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,9 @@ def init(self, driver_cores=0):
from bigdl.util.common import init_executor_gateway
init_executor_gateway(self.sc)
print("JavaGatewayServer has been successfully launched on executors")
self._start_cluster()
self._address_info = self._start_driver(num_cores=driver_cores)
redis_address = self._start_cluster()
self._address_info = self._start_driver(num_cores=driver_cores,
redis_address=redis_address)

print(self._address_info)
kill_redundant_log_monitors(self._address_info["redis_address"])
Expand Down Expand Up @@ -494,14 +495,14 @@ def _start_cluster(self):

self.ray_processesMonitor = ProcessMonitor(process_infos, self.sc, ray_rdd, self,
verbose=self.verbose)
return self
return self.ray_processesMonitor.master.master_addr

def _start_restricted_worker(self, num_cores, node_ip_address):
def _start_restricted_worker(self, num_cores, node_ip_address, redis_address):
extra_param = {"node-ip-address": node_ip_address}
if self.extra_params is not None:
extra_param.update(self.extra_params)
command = RayServiceFuncGenerator._get_raylet_command(
redis_address=self.redis_address,
redis_address=redis_address,
ray_exec="ray",
password=self.redis_password,
ray_node_cpu_cores=num_cores,
Expand All @@ -513,13 +514,14 @@ def _start_restricted_worker(self, num_cores, node_ip_address):
tag="raylet", fail_fast=True)
ProcessMonitor.register_shutdown_hook(pgid=process_info.pgid)

def _start_driver(self, num_cores=0):
def _start_driver(self, num_cores, redis_address):
print("Start to launch ray driver on local")
import ray.services
node_ip = ray.services.get_node_ip_address(self.redis_address)
node_ip = ray.services.get_node_ip_address(redis_address)
self._start_restricted_worker(num_cores=num_cores,
node_ip_address=node_ip)
node_ip_address=node_ip,
redis_address=redis_address)
ray.shutdown()
return ray.init(address=self.redis_address,
return ray.init(address=redis_address,
redis_password=self.ray_service.password,
node_ip_address=node_ip)

0 comments on commit d55a80c

Please sign in to comment.