diff --git a/python/orca/src/bigdl/orca/tfpark/estimator.py b/python/orca/src/bigdl/orca/tfpark/estimator.py index 552ada089c3..e7f4c7839c4 100644 --- a/python/orca/src/bigdl/orca/tfpark/estimator.py +++ b/python/orca/src/bigdl/orca/tfpark/estimator.py @@ -115,7 +115,7 @@ def _call_model_fn(self, features, labels, mode, config): return model_fn_results - def train(self, input_fn, steps=None): + def train(self, input_fn, steps=None, session_config=None): """Trains a model given training data `input_fn`. :param input_fn: A function that constructs the input data for evaluation. The @@ -163,7 +163,8 @@ def train(self, input_fn, steps=None): spec.loss, sess=sess, dataset=result, - model_dir=zoo_ckpt_path) + model_dir=zoo_ckpt_path, + session_config=session_config) opt.optimize(MaxIteration(steps)) sess.run(assign_step, feed_dict={add_step_input: steps}) @@ -223,9 +224,17 @@ def evaluate(self, input_fn, eval_methods, steps=None, checkpoint_path=None): sess.run(tf.global_variables_initializer()) if isinstance(spec.predictions, dict): if "mae" in eval_methods: + key = prediction_keys.PredictionKeys.PREDICTIONS + msg = "{} is required for evaluating mse,".format(key) + \ + " please add it in your model_fn predictions" + assert key in spec.prediction, msg outputs = [ spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS]] else: + key = prediction_keys.PredictionKeys.LOGITS + msg = "{} is required in for evaluating,".format(key) + \ + " please add it in your model_fn predictions" + assert key in spec.predictions, msg outputs = [ spec.predictions[prediction_keys.PredictionKeys.LOGITS]] else: