Skip to content

Commit

Permalink
Add unit tests for shard_size handling in the PyTorch estimator and the TF estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
shanyu-sys committed Feb 22, 2021
1 parent 0851ec1 commit e42ece0
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,23 @@ def test_dataframe_train_eval(self):
feature_cols=["feature"],
label_cols=["label"])

def test_dataframe_shard_size_train_eval(self):
    """Fit and evaluate on a Spark DataFrame with an explicit shard size.

    Sets ``OrcaContext.shard_size`` so the 100-row DataFrame is split into
    shards of 30 rows, exercising the uneven-last-shard path
    (100 % 30 != 0 -> shards of 30, 30, 30, 10).
    """
    from zoo.orca import OrcaContext
    OrcaContext.shard_size = 30
    sc = init_nncontext()
    rdd = sc.range(0, 100)
    # Each row: a 50-dim random feature vector and a binary label in {0, 1}.
    # NOTE: np.float was a deprecated alias for the builtin float (removed in
    # NumPy >= 1.24); use np.float64 explicitly.
    df = rdd.map(lambda x: (np.random.randn(50).astype(np.float64).tolist(),
                            [int(np.random.randint(0, 2, size=()))])
                 ).toDF(["feature", "label"])

    estimator = get_estimator(workers_per_node=2)
    estimator.fit(df, batch_size=4, epochs=2,
                  feature_cols=["feature"],
                  label_cols=["label"])
    estimator.evaluate(df, batch_size=4,
                       feature_cols=["feature"],
                       label_cols=["label"])

def test_dataframe_predict(self):

sc = init_nncontext()
Expand Down
27 changes: 27 additions & 0 deletions pyzoo/test/zoo/orca/learn/ray/tf/test_tf_ray_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,33 @@ def test_dataframe(self):
label_cols=["label"])
trainer.predict(df, feature_cols=["feature"]).collect()

def test_dataframe_shard_size(self):
    """Fit, evaluate and predict on a Spark DataFrame with a small shard size.

    Sets ``OrcaContext.shard_size`` so the 10-row DataFrame is split into
    shards of 3 rows, exercising the uneven-last-shard path
    (10 % 3 != 0 -> shards of 3, 3, 3, 1).
    """
    from zoo.orca import OrcaContext
    OrcaContext.shard_size = 3
    sc = init_nncontext()
    rdd = sc.range(0, 10)
    from pyspark.sql import SparkSession
    # An active SparkSession is required for rdd.toDF below.
    spark = SparkSession(sc)
    from pyspark.ml.linalg import DenseVector
    # Each row: a 1-dim DenseVector feature and a binary label.
    # NOTE: np.float was a deprecated alias for the builtin float (removed in
    # NumPy >= 1.24); use np.float64 explicitly.
    # NOTE: randint upper bound is exclusive, so (0, 1) always produced label
    # 0; use (0, 2) for labels in {0, 1}, consistent with the sibling
    # PyTorch estimator test.
    df = rdd.map(lambda x: (DenseVector(np.random.randn(1,).astype(np.float64)),
                            int(np.random.randint(0, 2, size=())))).toDF(["feature", "label"])

    config = {
        "lr": 0.8
    }
    trainer = Estimator.from_keras(
        model_creator=model_creator,
        verbose=True,
        config=config,
        workers_per_node=2)

    trainer.fit(df, epochs=1, batch_size=4, steps_per_epoch=25,
                feature_cols=["feature"],
                label_cols=["label"])
    trainer.evaluate(df, batch_size=4, num_steps=25, feature_cols=["feature"],
                     label_cols=["label"])
    trainer.predict(df, feature_cols=["feature"]).collect()

def test_dataframe_predict(self):
sc = init_nncontext()
rdd = sc.parallelize(range(20))
Expand Down

0 comments on commit e42ece0

Please sign in to comment.