Update CrossSiteEval #2886

Merged (6 commits, Aug 30, 2024)
examples/hello-world/step-by-step/cifar10/cse/cse.ipynb — 2 changes: 1 addition & 1 deletion

@@ -128,7 +128,7 @@
 "from nvflare.app_common.workflows.fedavg import FedAvg\n",
 "from nvflare.app_common.workflows.cross_site_eval import CrossSiteEval\n",
 "from nvflare.app_opt.pt.job_config.model import PTModel\n",
-"from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n",
+"from nvflare.job_config.script_runner import ScriptRunner\n",
 "\n",
 "\n",
 "if __name__ == \"__main__\":\n",
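Review note: the dropped FrameworkType import was unused, since ScriptRunner falls back to its default (PyTorch) framework. A minimal sketch of how the notebook's job cell might fit together after this change — FedJob, the nn.Linear stand-in model, the script path, and all constructor arguments are illustrative assumptions, not taken from the diff:

```python
import torch.nn as nn

from nvflare.app_common.workflows.cross_site_eval import CrossSiteEval
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.job_config.api import FedJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 2
    job = FedJob(name="cifar10_cse")

    # Server side: FedAvg training, then cross-site evaluation of the results.
    job.to(FedAvg(num_clients=n_clients, num_rounds=2), "server")
    job.to(CrossSiteEval(), "server")
    job.to(PTModel(nn.Linear(10, 2)), "server")  # stand-in for the notebook's Net()

    # Client side: ScriptRunner alone is enough; no FrameworkType needed
    # when the default framework applies.
    for i in range(n_clients):
        job.to(ScriptRunner(script="src/cifar10_fl.py"), f"site-{i + 1}")

    job.simulator_run("/tmp/nvflare/jobs/workdir")
```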
nvflare/app_common/abstract/model_persistor.py — 8 changes: 5 additions & 3 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from typing import Dict
 
 from nvflare.apis.fl_context import FLContext
 from nvflare.app_common.model_desc import ModelDescriptor
@@ -56,7 +57,7 @@ def save(self, learnable: ModelLearnable, fl_ctx: FLContext):
         if self.filter_id:
             _filter.process_post_save(learnable=learnable, fl_ctx=fl_ctx)
 
-    def get(self, model_file, fl_ctx: FLContext) -> object:
+    def get(self, model_file: str, fl_ctx: FLContext) -> object:
         learnable = self.get_model(model_file, fl_ctx)
 
         if self.filter_id:
@@ -90,13 +91,14 @@ def save_model(self, model: ModelLearnable, fl_ctx: FLContext):
         """
         pass
 
-    def get_model_inventory(self, fl_ctx: FLContext) -> {str: ModelDescriptor}:
+    def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
         """Get the model inventory of the ModelPersister.
 
         Args:
             fl_ctx: FLContext
 
-        Returns: { model_kind: ModelDescriptor }
+        Returns:
+            A dict of model_name: ModelDescriptor
 
         """
         pass
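Review note: the old return annotation `{str: ModelDescriptor}` is a dict literal, not a type — Python evaluates it into a throwaway dict object at definition time, and type checkers reject it. `Dict[str, ModelDescriptor]` is the proper generic. A small self-contained illustration, with `int` standing in for ModelDescriptor:

```python
from typing import Dict


def bad_inventory() -> {str: int}:  # evaluated as a dict literal, not a type
    return {}


def good_inventory() -> Dict[str, int]:  # a real generic type for checkers
    return {}


print(bad_inventory.__annotations__)   # {'return': {<class 'str'>: <class 'int'>}}
print(good_inventory.__annotations__)  # {'return': typing.Dict[str, int]}
```

On Python 3.9+, the builtin form `dict[str, ModelDescriptor]` would work as well.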
nvflare/app_common/workflows/base_model_controller.py — 2 changes: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ def start_controller(self, fl_ctx: FLContext) -> None:
         self.persistor = self.engine.get_component(self._persistor_id)
         if not isinstance(self.persistor, LearnablePersistor):
             self.warning(
-                f"Model Persistor {self._persistor_id} must be a LearnablePersistor type object, "
+                f"Persistor {self._persistor_id} must be a LearnablePersistor type object, "
                 f"but got {type(self.persistor)}"
             )
             self.persistor = None
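Review note: this warn-and-continue guard is the same pattern the PR adds to CrossSiteEval below — a component resolved by ID at runtime is type-checked before use and replaced with None rather than crashing the workflow. A self-contained sketch of the pattern, using stand-in classes instead of the real NVFlare types:

```python
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("controller")


class LearnablePersistor:  # stand-in for NVFlare's LearnablePersistor
    pass


def resolve_persistor(component: object):
    # Warn and fall back to None so the workflow can run without a persistor.
    if not isinstance(component, LearnablePersistor):
        log.warning("persistor must be a LearnablePersistor, got %s", type(component))
        return None
    return component


print(resolve_persistor(object()))              # None, after a warning
print(resolve_persistor(LearnablePersistor()))  # the persistor itself
```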
nvflare/app_common/workflows/cross_site_eval.py — 33 changes: 20 additions & 13 deletions

@@ -18,6 +18,7 @@
 import time
 
 from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.abstract.model_persistor import ModelPersistor
 from nvflare.app_common.app_constant import AppConstants, DefaultCheckpointFileName, ModelName
 from nvflare.app_common.utils.fl_model_utils import FLModelUtils
 from nvflare.fuel.utils import fobs
@@ -104,20 +105,24 @@ def run(self) -> None:
             callback=self._receive_local_model_cb,
         )
 
+        if self.persistor and not isinstance(self.persistor, ModelPersistor):
+            self.warning(
+                f"Model Persistor {self._persistor_id} must be a ModelPersistor type object, "
+                f"but got {type(self.persistor)}"
+            )
+            self.persistor = None
+
         # Obtain server models and send to clients for validation
-        if self.persistor:
-            for server_model_name in self._server_models:
-                server_model_path = os.path.join(self.get_app_dir(), server_model_name)
-                server_model_learnable = self.persistor.get_model_from_location(server_model_path, self.fl_ctx)
-                server_model = FLModelUtils.from_model_learnable(server_model_learnable)
-                self._send_validation_task(server_model_name, server_model)
-        else:
-            for server_model_name in self._server_models:
-                try:
+        for server_model_name in self._server_models:
+            try:
+                if self.persistor:
+                    server_model_learnable = self.persistor.get_model(server_model_name, self.fl_ctx)
+                    server_model = FLModelUtils.from_model_learnable(server_model_learnable)
+                else:
                     server_model = fobs.loadf(server_model_name)
-                    self._send_validation_task(server_model_name, server_model)
-                except Exception as e:
-                    self.exception(f"Unable to load server model {server_model_name}: {e}")
+            except Exception as e:
+                self.exception(f"Unable to load server model {server_model_name}: {e}")
+            self._send_validation_task(server_model_name, server_model)
 
         # Wait for all standing tasks to complete, since we used non-blocking `send_model()`
         while self.get_num_standing_tasks():
@@ -128,6 +133,7 @@ def run(self) -> None:
             time.sleep(self._task_check_period)
 
         self.save_results()
+        self.info("Stop Cross-Site Evaluation.")
 
     def _receive_local_model_cb(self, model: FLModel):
         client_name = model.meta["client_name"]
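Review note: with the two branches unified, the no-persistor fallback still loads models through FOBS (NVFlare's object serialization). A minimal round-trip sketch of that path — the `dumpf` helper call, the file path, and the plain-dict payload are illustrative assumptions:

```python
from nvflare.fuel.utils import fobs

model = {"weight": [0.1, 0.2]}  # stand-in payload for a serialized server model
fobs.dumpf(model, "/tmp/server_model.fobs")      # hypothetical file location
restored = fobs.loadf("/tmp/server_model.fobs")  # same call run() uses as fallback
assert restored == model
```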
@@ -193,5 +199,6 @@ def save_results(self):
             os.makedirs(cross_val_res_dir)
 
         res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
+        self.info(f"saving validation result {self._json_val_results} to {res_file_path}")
         with open(res_file_path, "w") as f:
-            json.dump(self._json_val_results, f)
+            f.write(json.dumps(self._json_val_results, indent=2))
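Review note: the serialization change only affects layout — `json.dump(obj, f)` writes one compact line, while `json.dumps(obj, indent=2)` pretty-prints the same data for human inspection. A quick comparison, with an illustrative result shape:

```python
import json

results = {"site-1": {"site-2": {"val_accuracy": 0.91}}}  # illustrative shape

print(json.dumps(results))            # old behavior: one compact line
print(json.dumps(results, indent=2))  # new behavior: indented, easier to eyeball
```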
nvflare/app_opt/pt/file_model_persistor.py — 7 changes: 3 additions & 4 deletions

@@ -259,19 +259,18 @@ def get_model(self, model_file: str, fl_ctx: FLContext) -> ModelLearnable:
             return None
 
         location = desc.location
-        return self.get_model_from_location(location, fl_ctx)
+        return self._get_model_from_location(location, fl_ctx)
 
-    def get_model_from_location(self, location, fl_ctx):
+    def _get_model_from_location(self, location, fl_ctx):
         try:
-            # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Use the "cpu" to load the global model weights, avoid GPU out of memory
             device = "cpu"
             data = torch.load(location, map_location=device)
             persistence_manager = PTModelPersistenceFormatManager(data, default_train_conf=self.default_train_conf)
             return persistence_manager.to_model_learnable(self.exclude_vars)
         except Exception:
             self.log_exception(fl_ctx, "error loading checkpoint from {}".format(location))
-            return {}
+            return None
 
     def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
         model_inventory = {}
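Review note: loading with `map_location="cpu"` rehydrates GPU-saved tensors into host memory, so the server process never allocates GPU memory just to read a checkpoint. A minimal sketch; the checkpoint path is illustrative:

```python
import torch

ckpt = {"weight": torch.zeros(2, 2)}
torch.save(ckpt, "/tmp/FL_global_model.pt")  # hypothetical location

# map_location forces tensors onto CPU even if they were saved from CUDA,
# which is what _get_model_from_location() relies on to avoid GPU OOM.
data = torch.load("/tmp/FL_global_model.pt", map_location="cpu")
print(data["weight"].device)  # cpu
```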