diff --git a/python/orca/src/bigdl/orca/data/ray_xshards.py b/python/orca/src/bigdl/orca/data/ray_xshards.py index f78cea0fb75..bdfbb8e431f 100644 --- a/python/orca/src/bigdl/orca/data/ray_xshards.py +++ b/python/orca/src/bigdl/orca/data/ray_xshards.py @@ -16,7 +16,7 @@ from collections import defaultdict import ray -import ray.services +import ray._private.services import uuid import random @@ -67,8 +67,8 @@ def get_partitions(self): def write_to_ray(idx, partition, redis_address, redis_password, partition_store_names): if not ray.is_initialized(): - ray.init(address=redis_address, redis_password=redis_password, ignore_reinit_error=True) - ip = ray.services.get_node_ip_address() + ray.init(address=redis_address, _redis_password=redis_password, ignore_reinit_error=True) + ip = ray._private.services.get_node_ip_address() local_store_name = None for name in partition_store_names: if name.endswith(ip): @@ -77,7 +77,7 @@ def write_to_ray(idx, partition, redis_address, redis_password, partition_store_ if local_store_name is None: local_store_name = random.choice(partition_store_names) - local_store = ray.util.get_actor(local_store_name) + local_store = ray.get_actor(local_store_name) # directly calling ray.put will set this driver as the owner of this object, # when the spark job finished, the driver might exit and make the object @@ -87,17 +87,15 @@ def write_to_ray(idx, partition, redis_address, redis_password, partition_store_ shard_ref = ray.put(shard) result.append(local_store.upload_shards.remote((idx, shard_id), shard_ref)) ray.get(result) - ray.shutdown() return [(idx, local_store_name.split(":")[-1], local_store_name)] def get_from_ray(idx, redis_address, redis_password, idx_to_store_name): if not ray.is_initialized(): - ray.init(address=redis_address, redis_password=redis_password, ignore_reinit_error=True) - local_store_handle = ray.util.get_actor(idx_to_store_name[idx]) + ray.init(address=redis_address, _redis_password=redis_password, 
ignore_reinit_error=True) + local_store_handle = ray.get_actor(idx_to_store_name[idx]) partition = ray.get(local_store_handle.get_partition.remote(idx)) - ray.shutdown() return partition @@ -141,7 +139,13 @@ def to_spark_xshards(self): rdd = sc.parallelize([0] * num_parts * 10, num_parts)\ .mapPartitionsWithIndex( lambda idx, _: get_from_ray(idx, address, password, partition2store)) - spark_xshards = SparkXShards(rdd) + + # the reason why we trigger computation here is to ensure we get the data + # from ray before the RayXShards goes out of scope and the data get garbage collected + from pyspark.storagelevel import StorageLevel + rdd = rdd.cache() + result_rdd = rdd.map(lambda x: x) # sparkxshards will uncache the rdd when gc + spark_xshards = SparkXShards(result_rdd) return spark_xshards def _get_multiple_partition_refs(self, ids): @@ -159,7 +163,7 @@ def transform_shards_with_actors(self, actors, func, and run func for each actor and partition_ref pair. Actors should have a `get_node_ip` method to achieve locality scheduling. - The `get_node_ip` method should call ray.services.get_node_ip_address() + The `get_node_ip` method should call ray._private.services.get_node_ip_address() to return the correct ip address. 
The `func` should take an actor and a partition_ref as argument and @@ -304,7 +308,7 @@ def _from_spark_xshards_ray_api(spark_xshards): ray_ctx = RayContext.get() address = ray_ctx.redis_address password = ray_ctx.redis_password - driver_ip = ray.services.get_node_ip_address() + driver_ip = ray._private.services.get_node_ip_address() uuid_str = str(uuid.uuid4()) resources = ray.cluster_resources() nodes = [] @@ -320,6 +324,9 @@ def _from_spark_xshards_ray_api(spark_xshards): store = ray.remote(num_cpus=0, resources={node: 1e-4})(LocalStore)\ .options(name=name).remote() partition_stores[name] = store + + # actor creation is async, this is to make sure they all have been started + ray.get([v.get_partitions.remote() for v in partition_stores.values()]) partition_store_names = list(partition_stores.keys()) result = spark_xshards.rdd.mapPartitionsWithIndex(lambda idx, part: write_to_ray( idx, part, address, password, partition_store_names)).collect() diff --git a/python/orca/src/bigdl/orca/data/shard.py b/python/orca/src/bigdl/orca/data/shard.py index dfdf8b0815d..bb67e677dfe 100644 --- a/python/orca/src/bigdl/orca/data/shard.py +++ b/python/orca/src/bigdl/orca/data/shard.py @@ -114,7 +114,6 @@ def partition(data, num_shards=None): data_shards = SparkXShards(rdd) return data_shards - class SparkXShards(XShards): """ A collection of data which can be pre-processed in parallel on Spark diff --git a/python/orca/src/bigdl/orca/learn/horovod/horovod_ray_runner.py b/python/orca/src/bigdl/orca/learn/horovod/horovod_ray_runner.py index 8fbcb55b60a..ff0118f3f24 100644 --- a/python/orca/src/bigdl/orca/learn/horovod/horovod_ray_runner.py +++ b/python/orca/src/bigdl/orca/learn/horovod/horovod_ray_runner.py @@ -22,7 +22,7 @@ class HorovodWorker: def ip_addr(self): import ray - return ray.services.get_node_ip_address() + return ray._private.services.get_node_ip_address() def set_gloo_iface(self): ip_addr = self.ip_addr() @@ -111,7 +111,7 @@ def __init__(self, ray_ctx, 
worker_cls=None, worker_param=None, workers_per_node global_rendezv_port = self.global_rendezv.start() self.global_rendezv.init(self.host_alloc_plan) - driver_ip = ray.services.get_node_ip_address() + driver_ip = ray._private.services.get_node_ip_address() common_envs = { "HOROVOD_GLOO_RENDEZVOUS_ADDR": driver_ip, diff --git a/python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py b/python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py index 057bb3895aa..67723d59831 100644 --- a/python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py +++ b/python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py @@ -18,7 +18,7 @@ import time import logging import subprocess -import ray.services +import ray._private.services import mxnet as mx from mxnet import gluon from zoo.ray.utils import to_list @@ -214,7 +214,7 @@ def shutdown(self): def get_node_ip(self): """Returns the IP address of the current node.""" if "node_ip" not in self.__dict__: - self.node_ip = ray.services.get_node_ip_address() + self.node_ip = ray._private.services.get_node_ip_address() return self.node_ip def find_free_port(self): diff --git a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py index 874d152c8d0..50e2f9919c1 100644 --- a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py +++ b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py @@ -129,7 +129,7 @@ def setup_horovod(self): self.setup_operator(self.models) def setup_address(self): - ip = ray.services.get_node_ip_address() + ip = ray._private.services.get_node_ip_address() port = find_free_port() return f"tcp://{ip}:{port}" @@ -213,7 +213,7 @@ def setup_operator(self, training_models): def get_node_ip(self): """Returns the IP address of the current node.""" - return ray.services.get_node_ip_address() + return ray._private.services.get_node_ip_address() def find_free_port(self): """Finds a free port on the current node.""" diff --git 
a/python/orca/src/bigdl/orca/learn/tf2/tf_runner.py b/python/orca/src/bigdl/orca/learn/tf2/tf_runner.py index a1aa0fbef95..6e51824ca3e 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/tf_runner.py +++ b/python/orca/src/bigdl/orca/learn/tf2/tf_runner.py @@ -35,7 +35,6 @@ import numpy as np import ray -import ray.services from contextlib import closing import logging import socket @@ -473,7 +472,7 @@ def shutdown(self): def get_node_ip(self): """Returns the IP address of the current node.""" - return ray.services.get_node_ip_address() + return ray._private.services.get_node_ip_address() def find_free_port(self): """Finds a free port on the current node.""" diff --git a/python/orca/test/bigdl/orca/data/test_ray_xshards.py b/python/orca/test/bigdl/orca/data/test_ray_xshards.py index 2f1f7aca8fa..a0cb3ab92b3 100644 --- a/python/orca/test/bigdl/orca/data/test_ray_xshards.py +++ b/python/orca/test/bigdl/orca/data/test_ray_xshards.py @@ -56,7 +56,7 @@ class Add1Actor: def get_node_ip(self): import ray - return ray.services.get_node_ip_address() + return ray._private.services.get_node_ip_address() def add_one(self, partition): return [{k: (value + 1) for k, value in shards.items()} for shards in partition]