Skip to content

Commit

Permalink
Allow customization of BaseFedJob (#2985)
Browse files Browse the repository at this point in the history
* Add CommonComponentsJob

* Fix format

* Address comments

* Fix issue
  • Loading branch information
YuanTingHsieh authored Oct 3, 2024
1 parent cb08967 commit f0c6e80
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 36 deletions.
74 changes: 58 additions & 16 deletions nvflare/app_opt/pt/job_config/base_fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

from torch import nn as nn

from nvflare import FedJob
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.job_config.api import FedJob, validate_object_for_job


class BaseFedJob(FedJob):
Expand All @@ -32,12 +36,18 @@ def __init__(
min_clients: int = 1,
mandatory_clients: Optional[List[str]] = None,
key_metric: str = "accuracy",
validation_json_generator: Optional[ValidationJsonGenerator] = None,
intime_model_selector: Optional[IntimeModelSelector] = None,
convert_to_fed_event: Optional[ConvertToFedEvent] = None,
analytics_receiver: Optional[AnalyticsReceiver] = None,
model_persistor: Optional[ModelPersistor] = None,
model_locator: Optional[ModelLocator] = None,
):
"""PyTorch BaseFedJob.
Configures server side FedAvg controller, persistor with initial model, and widgets.
Configures ValidationJsonGenerator, IntimeModelSelector, AnalyticsReceiver, ConvertToFedEvent.
User must add executors.
User must add controllers and executors.
Args:
initial_model (nn.Module): initial PyTorch Model. Defaults to None.
Expand All @@ -47,26 +57,58 @@ def __init__(
key_metric (str, optional): Metric used to determine if the model is globally best.
if metrics are a `dict`, `key_metric` can select the metric used for global model selection.
Defaults to "accuracy".
validation_json_generator (ValidationJsonGenerator, optional): A component for generating validation results.
if not provided, a ValidationJsonGenerator will be configured.
intime_model_selector (IntimeModelSelector, optional): A component for selecting the model.
if not provided, an IntimeModelSelector using `key_metric` will be configured (only when `key_metric` is set).
convert_to_fed_event (ConvertToFedEvent, optional): A component to convert certain events to fed events.
if not provided, a ConvertToFedEvent object will be created.
analytics_receiver (AnalyticsReceiver, optional): Receive analytics.
If not provided, a TBAnalyticsReceiver will be configured.
model_persistor (ModelPersistor, optional): how to persist the model.
model_locator (ModelLocator, optional): how to locate the model.
"""
super().__init__(name, min_clients, mandatory_clients)
self.key_metric = key_metric
super().__init__(
name=name,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
)

self.initial_model = initial_model
self.comp_ids = {}

component = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=component)
if validation_json_generator:
validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator)
else:
validation_json_generator = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=validation_json_generator)

if intime_model_selector:
validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector)
self.to_server(id="model_selector", obj=intime_model_selector)
elif key_metric:
self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric))

if convert_to_fed_event:
validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent)
else:
convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE])
self.convert_to_fed_event = convert_to_fed_event

if self.key_metric:
component = IntimeModelSelector(key_metric=self.key_metric)
self.to_server(id="model_selector", obj=component)
if analytics_receiver:
validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver)
else:
analytics_receiver = TBAnalyticsReceiver()

# TODO: make different tracking receivers configurable
component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
self.to_server(id="receiver", obj=component)
self.to_server(
id="receiver",
obj=analytics_receiver,
)

if initial_model:
self.comp_ids.update(self.to_server(PTModel(initial_model)))
self.comp_ids.update(
self.to_server(PTModel(model=initial_model, persistor=model_persistor, locator=model_locator))
)

def set_up_client(self, target: str):
component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.")
self.to(id="event_to_fed", obj=component, target=target)
self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)
19 changes: 16 additions & 3 deletions nvflare/app_opt/pt/job_config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch.nn as nn

from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_opt.pt import PTFileModelPersistor
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
from nvflare.job_config.api import validate_object_for_job


class PTModel:
def __init__(self, model, persistor: Optional[ModelPersistor] = None, locator: Optional[ModelLocator] = None):
    """PyTorch model wrapper.

    If model is an nn.Module, `add_to_fed_job` adds a persistor and a locator for it:
    the ones supplied here, or a PTFileModelPersistor / PTFileModelLocator when omitted.

    Args:
        model (any): model
        persistor (ModelPersistor, optional): how to persist the model.
            Defaults to None, in which case a PTFileModelPersistor is created for the model.
        locator (ModelLocator, optional): how to locate the model.
            Defaults to None, in which case a PTFileModelLocator is created for the persistor.
    """
    self.model = model

    # Both components default to None so that existing `PTModel(model)` callers keep working,
    # matching TFModel, whose persistor is optional as well.
    # Only user-supplied components are type-checked; the defaults are built in `add_to_fed_job`.
    if persistor:
        validate_object_for_job("persistor", persistor, ModelPersistor)
    self.persistor = persistor

    if locator:
        validate_object_for_job("locator", locator, ModelLocator)
    self.locator = locator

def add_to_fed_job(self, job, ctx):
"""This method is used by Job API.
Expand All @@ -40,10 +53,10 @@ def add_to_fed_job(self, job, ctx):
dictionary of ids of component added
"""
if isinstance(self.model, nn.Module): # if model, create a PT persistor
persistor = PTFileModelPersistor(model=self.model)
persistor = self.persistor if self.persistor else PTFileModelPersistor(model=self.model)
persistor_id = job.add_component(comp_id="persistor", obj=persistor, ctx=ctx)

locator = PTFileModelLocator(pt_persistor_id=persistor_id)
locator = self.locator if self.locator else PTFileModelLocator(pt_persistor_id=persistor_id)
locator_id = job.add_component(comp_id="locator", obj=locator, ctx=ctx)
return {"persistor_id": persistor_id, "locator_id": locator_id}
else:
Expand Down
69 changes: 53 additions & 16 deletions nvflare/app_opt/tf/job_config/base_fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

import tensorflow as tf

from nvflare import FedJob
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.tf.job_config.model import TFModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.job_config.api import FedJob, validate_object_for_job


class BaseFedJob(FedJob):
Expand All @@ -32,12 +35,17 @@ def __init__(
min_clients: int = 1,
mandatory_clients: Optional[List[str]] = None,
key_metric: str = "accuracy",
validation_json_generator: Optional[ValidationJsonGenerator] = None,
intime_model_selector: Optional[IntimeModelSelector] = None,
convert_to_fed_event: Optional[ConvertToFedEvent] = None,
analytics_receiver: Optional[AnalyticsReceiver] = None,
model_persistor: Optional[ModelPersistor] = None,
):
"""TensorFlow BaseFedJob.
Configures server side FedAvg controller, persistor with initial model, and widgets.
Configures ValidationJsonGenerator, IntimeModelSelector, TBAnalyticsReceiver, ConvertToFedEvent.
User must add executors.
User must add controllers and executors.
Args:
initial_model (tf.keras.Model): initial TensorFlow Model. Defaults to None.
Expand All @@ -47,26 +55,55 @@ def __init__(
key_metric (str, optional): Metric used to determine if the model is globally best.
if metrics are a `dict`, `key_metric` can select the metric used for global model selection.
Defaults to "accuracy".
validation_json_generator (ValidationJsonGenerator, optional): A component for generating validation results.
if not provided, a ValidationJsonGenerator will be configured.
intime_model_selector (IntimeModelSelector, optional): A component for selecting the model.
if not provided, an IntimeModelSelector using `key_metric` will be configured (only when `key_metric` is set).
convert_to_fed_event (ConvertToFedEvent, optional): A component to convert certain events to fed events.
if not provided, a ConvertToFedEvent object will be created.
analytics_receiver (AnalyticsReceiver, optional): Receive analytics.
If not provided, a TBAnalyticsReceiver will be configured.
model_persistor (ModelPersistor, optional): how to persist the model.
"""
super().__init__(name, min_clients, mandatory_clients)
self.key_metric = key_metric
super().__init__(
name=name,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
)

self.initial_model = initial_model
self.comp_ids = {}

component = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=component)
if validation_json_generator:
validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator)
else:
validation_json_generator = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=validation_json_generator)

if intime_model_selector:
validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector)
self.to_server(id="model_selector", obj=intime_model_selector)
elif key_metric:
self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric))

if convert_to_fed_event:
validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent)
else:
convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE])
self.convert_to_fed_event = convert_to_fed_event

if self.key_metric:
component = IntimeModelSelector(key_metric=self.key_metric)
self.to_server(id="model_selector", obj=component)
if analytics_receiver:
validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver)
else:
analytics_receiver = TBAnalyticsReceiver()

# TODO: make different tracking receivers configurable
component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
self.to_server(id="receiver", obj=component)
self.to_server(
id="receiver",
obj=analytics_receiver,
)

if initial_model:
self.comp_ids["persistor_id"] = self.to_server(TFModel(initial_model))
self.comp_ids["persistor_id"] = self.to_server(TFModel(model=initial_model, persistor=model_persistor))

def set_up_client(self, target: str):
component = ConvertToFedEvent(events_to_convert=["analytix_log_stats"], fed_event_prefix="fed.")
self.to(id="event_to_fed", obj=component, target=target)
self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)
2 changes: 1 addition & 1 deletion nvflare/app_opt/tf/job_config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, model, persistor: Optional[ModelPersistor] = None):
"""
self.model = model

if self.persistor:
if persistor:
validate_object_for_job("persistor", persistor, ModelPersistor)
self.persistor = persistor

Expand Down

0 comments on commit f0c6e80

Please sign in to comment.