Skip to content

Commit

Permalink
Fix tf model persistor and tf model (NVIDIA#2984)
Browse files Browse the repository at this point in the history
* add missing filter id arg in tf model persistor

* Update TFModel

* Address comment
  • Loading branch information
YuanTingHsieh committed Oct 3, 2024
1 parent 80b611c commit 56ee91b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
30 changes: 21 additions & 9 deletions nvflare/app_opt/tf/job_config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import tensorflow as tf

from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_opt.tf.model_persistor import TFModelPersistor
from nvflare.job_config.api import validate_object_for_job


class TFModel:
def __init__(self, model):
"""TensorFLow model wrapper.
def __init__(self, model, persistor: Optional[ModelPersistor] = None):
"""TensorFlow model wrapper.
If model is a tf.keras.Model, add a TFModelPersistor with the model.
If persistor is provided, use it.
Else if model is a tf.keras.Model, add a TFModelPersistor with the model.
Args:
model (any): model
persistor (Optional[ModelPersistor]): A ModelPersistor,
if provided will ignore argument `model`, defaults to None.
"""
self.model = model

if self.persistor:
validate_object_for_job("persistor", persistor, ModelPersistor)
self.persistor = persistor

def add_to_fed_job(self, job, ctx):
"""This method is used by Job API.
Expand All @@ -38,11 +49,12 @@ def add_to_fed_job(self, job, ctx):
Returns:
dictionary of ids of component added
"""
if isinstance(self.model, tf.keras.Model): # if model, create a TF persistor
if self.persistor:
persistor = self.persistor
elif isinstance(self.model, tf.keras.Model):
# if model is a tf.keras.Model, creates a TFModelPersistor
persistor = TFModelPersistor(model=self.model)
persistor_id = job.add_component(comp_id="persistor", obj=persistor, ctx=ctx)
return persistor_id
else:
raise ValueError(
f"Unable to add {self.model} to job with TFModelPersistor. Expected tf.keras.Model but got {type(self.model)}."
)
raise ValueError(f"Unsupported type for model: {type(self.model)}.")
persistor_id = job.add_component(comp_id="persistor", obj=persistor, ctx=ctx)
return persistor_id
6 changes: 4 additions & 2 deletions nvflare/app_opt/tf/model_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@


class TFModelPersistor(ModelPersistor):
def __init__(self, model: tf.keras.Model, save_name="tf_model.weights.h5"):
super().__init__()
def __init__(self, model: tf.keras.Model, save_name="tf_model.weights.h5", filter_id: str = None):
super().__init__(
filter_id=filter_id,
)
self.save_name = save_name
self.model = model

Expand Down

0 comments on commit 56ee91b

Please sign in to comment.