Skip to content

Commit

Permalink
Add optim method option in TFOptimizer.from_keras (intel-analytics#1574)
Browse files Browse the repository at this point in the history
  • Loading branch information
hkvision authored and yangw1234 committed Sep 24, 2021
1 parent df39e3b commit 5e8c3de
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/orca/src/bigdl/orca/net/tf_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def from_loss(cls, loss, optim_method, session=None, val_outputs=None,
clip_value=clip_value, **kwargs)

@classmethod
def from_keras(cls, keras_model, dataset, val_spilt=0.0, **kwargs):
def from_keras(cls, keras_model, dataset, optim_method=None, val_spilt=0.0, **kwargs):
import tensorflow.keras.backend as K
loss = keras_model.total_loss
inputs = keras_model.inputs + keras_model.targets
Expand All @@ -290,7 +290,9 @@ def from_keras(cls, keras_model, dataset, val_spilt=0.0, **kwargs):
clip_value = (-keras_optimizer.clipvalue, keras_optimizer.clipvalue)

sess = K.get_session()
optim_method = TFOptimizer.to_bigdl_optim_method(keras_optimizer)
if optim_method is None:
optim_method = keras_optimizer
optim_method = TFOptimizer.to_bigdl_optim_method(optim_method)

if keras_model.metrics and (dataset.get_validation_data() is not None or val_spilt != 0.0):
if isinstance(keras_model.metrics, dict):
Expand Down

0 comments on commit 5e8c3de

Please sign in to comment.