Skip to content

Commit

Permalink
support session_config in estimator and refine evalute (intel-analyti…
Browse files Browse the repository at this point in the history
  • Loading branch information
yangw1234 committed Aug 3, 2020
1 parent debaa6a commit cd17a15
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/orca/src/bigdl/orca/tfpark/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _call_model_fn(self, features, labels, mode, config):

return model_fn_results

def train(self, input_fn, steps=None):
def train(self, input_fn, steps=None, session_config=None):
"""Trains a model given training data `input_fn`.
:param input_fn: A function that constructs the input data for evaluation. The
Expand Down Expand Up @@ -163,7 +163,8 @@ def train(self, input_fn, steps=None):
spec.loss,
sess=sess,
dataset=result,
model_dir=zoo_ckpt_path)
model_dir=zoo_ckpt_path,
session_config=session_config)

opt.optimize(MaxIteration(steps))
sess.run(assign_step, feed_dict={add_step_input: steps})
Expand Down Expand Up @@ -223,9 +224,17 @@ def evaluate(self, input_fn, eval_methods, steps=None, checkpoint_path=None):
sess.run(tf.global_variables_initializer())
if isinstance(spec.predictions, dict):
if "mae" in eval_methods:
key = prediction_keys.PredictionKeys.PREDICTIONS
msg = "{} is required for evaluating mse,".format(key) + \
" please add it in your model_fn predictions"
assert key in spec.prediction, msg
outputs = [
spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS]]
else:
key = prediction_keys.PredictionKeys.LOGITS
msg = "{} is required in for evaluating,".format(key) + \
" please add it in your model_fn predictions"
assert key in spec.predictions, msg
outputs = [
spec.predictions[prediction_keys.PredictionKeys.LOGITS]]
else:
Expand Down

0 comments on commit cd17a15

Please sign in to comment.