Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IPC agent and exchanger #2435

Merged
merged 6 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nvflare/apis/dxo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class DataKind(object):
COLLECTION = "COLLECTION" # Dict or List of DXO objects
STATISTICS = "STATISTICS"
PSI = "PSI"
APP_DEFINED = "APP_DEFINED" # data format is app defined


class MetaKey(FLMetaKey):
Expand Down Expand Up @@ -128,7 +129,7 @@ def validate(self) -> str:
if self.data is None:
return "missing data"

if not isinstance(self.data, dict):
if self.data_kind != DataKind.APP_DEFINED and not isinstance(self.data, dict):
return "invalid data: expect dict but got {}".format(type(self.data))

if self.meta is not None and not isinstance(self.meta, dict):
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ReturnCode(object):
UNSAFE_JOB = "UNSAFE_JOB"
SERVER_NOT_READY = "SERVER_NOT_READY"
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
EARLY_TERMINATION = "EARLY_TERMINATION"


class MachineStatus(Enum):
Expand Down
13 changes: 13 additions & 0 deletions nvflare/app_common/app_defined/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
yanchengnv marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
75 changes: 75 additions & 0 deletions nvflare/app_common/app_defined/aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.model import ModelLearnableKey
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType

from .component_base import ComponentBase


class AppDefinedAggregator(Aggregator, ComponentBase, ABC):
    """Aggregator base class for training results in an app-defined format.

    Subclasses implement ``reset``, ``processing_training_result`` and
    ``aggregate_training_result``. This class unpacks each incoming Shareable
    into ``(data, meta)`` for the subclass and wraps the aggregated result in a
    DXO of kind ``DataKind.APP_DEFINED``.
    """

    def __init__(self):
        # Both bases are initialized explicitly, in declaration order.
        Aggregator.__init__(self)
        ComponentBase.__init__(self)
        # Value of AppConstants.CURRENT_ROUND read at the last ROUND_STARTED event.
        self.current_round = None
        # WEIGHTS entry of the global model read at the last ROUND_STARTED event.
        self.base_model_obj = None

    def handle_event(self, event_type, fl_ctx: FLContext):
        """Refresh per-round state when a round starts.

        Only ``AppEventType.ROUND_STARTED`` is handled; every other event type
        is ignored.

        Args:
            event_type: type of the fired event
            fl_ctx: FL context the event was fired with

        """
        if event_type == AppEventType.ROUND_STARTED:
            self.fl_ctx = fl_ctx
            self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
            base_model_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
            # base_model_obj is only updated when the global model prop is a dict;
            # otherwise the previously stored value is left untouched.
            if isinstance(base_model_learnable, dict):
                self.base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS)
            self.reset()

    @abstractmethod
    def reset(self):
        """Reset the aggregation state; called at every ROUND_STARTED event."""
        pass

    @abstractmethod
    def processing_training_result(self, client_name: str, trained_weights: Any, trained_meta: dict) -> bool:
        """Process the training result received from one client.

        Args:
            client_name: identity name taken from the peer context of the submission
            trained_weights: the ``data`` of the DXO extracted from the submitted Shareable
            trained_meta: the ``meta`` of the DXO extracted from the submitted Shareable

        Returns: a bool that is passed back unchanged as the result of ``accept``

        """
        pass

    @abstractmethod
    def aggregate_training_result(self) -> (Any, dict):
        """Aggregate the training results processed so far.

        Returns: a tuple of (aggregated result, aggregated meta); they become the
        ``data`` and ``meta`` of the DXO built by ``aggregate``

        """
        pass

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Unpack a client submission and hand it to ``processing_training_result``.

        Args:
            shareable: the submitted result, convertible to a DXO
            fl_ctx: FL context of the submission; its peer context supplies the client name

        Returns: whatever ``processing_training_result`` returns

        """
        dxo = from_shareable(shareable)
        trained_weights = dxo.data
        trained_meta = dxo.meta
        self.fl_ctx = fl_ctx
        peer_ctx = fl_ctx.get_peer_context()
        client_name = peer_ctx.get_identity_name()
        return self.processing_training_result(client_name, trained_weights, trained_meta)

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Build the aggregated Shareable from ``aggregate_training_result``.

        Args:
            fl_ctx: FL context, stored so the logging helpers can use it

        Returns: a Shareable made from a DXO of kind ``DataKind.APP_DEFINED``

        """
        self.fl_ctx = fl_ctx
        aggregated_result, aggregated_meta = self.aggregate_training_result()
        dxo = DXO(
            data_kind=DataKind.APP_DEFINED,
            data=aggregated_result,
            meta=aggregated_meta,
        )
        # The message used to read "learnable_to_shareable", which is a method of
        # the shareable generator and made the log point at the wrong component.
        self.debug(f"aggregated result: {dxo.data}")
        return dxo.to_shareable()
87 changes: 87 additions & 0 deletions nvflare/app_common/app_defined/component_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.fl_component import FLComponent


class ComponentBase(FLComponent):
    """FLComponent with logging shortcuts bound to a stored FL context.

    ``fl_ctx`` starts as None. Each shortcut forwards the message, together
    with whatever is currently stored in ``fl_ctx``, to the matching
    ``log_*`` method.
    """

    def __init__(self):
        FLComponent.__init__(self)
        self.fl_ctx = None

    def debug(self, msg: str):
        """Log a DEBUG message using the stored FL context.

        Args:
            msg: the message to be logged

        """
        ctx = self.fl_ctx
        self.log_debug(ctx, msg)

    def info(self, msg: str):
        """Log an INFO message using the stored FL context.

        Args:
            msg: the message to be logged

        """
        ctx = self.fl_ctx
        self.log_info(ctx, msg)

    def error(self, msg: str):
        """Log an ERROR message using the stored FL context.

        Args:
            msg: the message to be logged

        """
        ctx = self.fl_ctx
        self.log_error(ctx, msg)

    def warning(self, msg: str):
        """Log a WARNING message using the stored FL context.

        Args:
            msg: the message to be logged

        """
        ctx = self.fl_ctx
        self.log_warning(ctx, msg)

    def exception(self, msg: str):
        """Log an EXCEPTION message using the stored FL context.

        Args:
            msg: the message to be logged

        """
        ctx = self.fl_ctx
        self.log_exception(ctx, msg)

    def critical(self, msg: str):
        """Log a CRITICAL message using the stored FL context.

        Args:
            msg: the message to be logged

        """
        ctx = self.fl_ctx
        self.log_critical(ctx, msg)
57 changes: 57 additions & 0 deletions nvflare/app_common/app_defined/model_persistor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor

from .component_base import ComponentBase


class AppDefinedModelPersistor(ModelPersistor, ComponentBase, ABC):
    """Model persistor for app-defined model objects.

    Subclasses supply ``read_model`` and ``write_model``. This class wraps the
    model object into a ModelLearnable on load and unwraps it on save.
    """

    def __init__(self):
        ModelPersistor.__init__(self)
        ComponentBase.__init__(self)

    @abstractmethod
    def read_model(self) -> Any:
        """Read the model object from wherever the subclass keeps it.

        Returns: a model object
        """

    @abstractmethod
    def write_model(self, model_obj: Any):
        """Write the model object to wherever the subclass keeps it.

        Args:
            model_obj: the model object to be saved

        Returns: None

        """

    def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
        """Read the model and wrap it as the weights of a ModelLearnable.

        Args:
            fl_ctx: FL context, stored so the logging helpers can use it

        Returns: a ModelLearnable whose weights are the object from ``read_model``
        and whose meta props are an empty dict

        """
        self.fl_ctx = fl_ctx
        return make_model_learnable(weights=self.read_model(), meta_props={})

    def save_model(self, learnable: ModelLearnable, fl_ctx: FLContext):
        """Pass the WEIGHTS entry of the learnable to ``write_model``.

        Args:
            learnable: the ModelLearnable to be saved
            fl_ctx: FL context, stored so the logging helpers can use it

        """
        self.fl_ctx = fl_ctx
        model_obj = learnable.get(ModelLearnableKey.WEIGHTS)
        self.write_model(model_obj)
94 changes: 94 additions & 0 deletions nvflare/app_common/app_defined/shareable_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
yanchengnv marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants

from .component_base import ComponentBase


class AppDefinedShareableGenerator(ShareableGenerator, ComponentBase, ABC):
    """Shareable generator for models in an app-defined format.

    Subclasses supply ``model_to_trainable`` and ``apply_weights_to_model``.
    This class handles the conversion between Learnable and Shareable, using
    DXOs of kind ``DataKind.APP_DEFINED``.
    """

    def __init__(self):
        ShareableGenerator.__init__(self)
        ComponentBase.__init__(self)
        self.current_round = None

    @abstractmethod
    def model_to_trainable(self, model_obj: Any) -> (Any, dict):
        """Turn the model object into weights and meta that can be sent to clients for training.

        Args:
            model_obj: model object

        Returns: a tuple of (weights, meta)

        The returned weights and meta will be for training and serializable
        """

    @abstractmethod
    def apply_weights_to_model(self, model_obj: Any, weights: Any, meta: dict) -> Any:
        """Apply trained weights and meta to the base model.

        Args:
            model_obj: base model object that weights will be applied to
            weights: trained weights
            meta: trained meta

        Returns: the updated model object

        """

    def learnable_to_shareable(self, learnable: Learnable, fl_ctx: FLContext) -> Shareable:
        """Convert the learnable into a Shareable carrying trainable weights.

        Args:
            learnable: learnable whose WEIGHTS entry is the base model object
            fl_ctx: FL context; also the source of the current round number

        Returns: a Shareable made from a DXO of kind ``DataKind.APP_DEFINED``

        """
        self.fl_ctx = fl_ctx
        self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        self.debug(f"{learnable=}")

        weights, meta = self.model_to_trainable(learnable.get(ModelLearnableKey.WEIGHTS))
        self.debug(f"trainable weights: {weights}")

        outgoing = DXO(data_kind=DataKind.APP_DEFINED, data=weights, meta=meta)
        self.debug(f"learnable_to_shareable: {outgoing.data}")
        return outgoing.to_shareable()

    def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
        """Apply the result carried by the shareable to the global model.

        Args:
            shareable: shareable convertible to a DXO holding trained weights and meta
            fl_ctx: FL context; supplies the current round and the global model

        Returns: a new ModelLearnable built from the updated model object and an
        empty meta dict. If the global model prop is missing or empty, a system
        panic is triggered and that prop value is returned instead.

        Raises:
            ValueError: if the global model is not a ModelLearnable

        """
        self.fl_ctx = fl_ctx
        self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        global_model = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)

        if not global_model:
            self.system_panic(reason="No global base model!", fl_ctx=fl_ctx)
            return global_model

        if not isinstance(global_model, ModelLearnable):
            raise ValueError(f"expect global model to be ModelLearnable but got {type(global_model)}")

        current_model_obj = global_model.get(ModelLearnableKey.WEIGHTS)
        result = from_shareable(shareable)
        updated_model_obj = self.apply_weights_to_model(
            model_obj=current_model_obj,
            weights=result.data,
            meta=result.meta,
        )
        return make_model_learnable(updated_model_obj, {})
Loading
Loading