Skip to content

Commit

Permalink
Support tensorflow 2.13
Browse files Browse the repository at this point in the history
  • Loading branch information
Shelnutt2 committed Jul 23, 2023
1 parent d7d67fd commit 5a3bb71
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,33 @@

from ._base import Meta, TileDBArtifact, Timestamp

FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
keras_major, keras_minor, keras_patch = keras.__version__.split(".")
FunctionalOrSequential = keras.models.Sequential
# Handle keras <=v2.10
if int(keras_major) <= 2 and int(keras_minor) <= 10:
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
TFOptimizer = keras.optimizers.TFOptimizer
get_json_type = keras.saving.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.saving.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.saving.saving_utils
# Handle keras >=v2.11
else:
elif int(keras_major) <= 2 and int(keras_minor) <= 12:
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
TFOptimizer = tf.keras.optimizers.legacy.Optimizer
get_json_type = keras.saving.legacy.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.saving.legacy.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.saving.legacy.saving_utils
else:
TFOptimizer = tf.keras.optimizers.legacy.Optimizer
get_json_type = keras.src.saving.legacy.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.src.saving.legacy.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.src.saving.legacy.saving_utils


class TensorflowKerasTileDBModel(TileDBArtifact[tf.keras.Model]):
Expand Down

0 comments on commit 5a3bb71

Please sign in to comment.