diff --git a/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb b/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb
index 4027ccda4a..3c4866fcf0 100644
--- a/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb
+++ b/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb
@@ -128,7 +128,7 @@
     "from nvflare.app_common.workflows.fedavg import FedAvg\n",
     "from nvflare.app_common.workflows.cross_site_eval import CrossSiteEval\n",
     "from nvflare.app_opt.pt.job_config.model import PTModel\n",
-    "from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n",
+    "from nvflare.job_config.script_runner import ScriptRunner\n",
     "\n",
     "\n",
     "if __name__ == \"__main__\":\n",
diff --git a/nvflare/app_common/abstract/model_persistor.py b/nvflare/app_common/abstract/model_persistor.py
index 31e1b08dfe..5b65088ae0 100644
--- a/nvflare/app_common/abstract/model_persistor.py
+++ b/nvflare/app_common/abstract/model_persistor.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from typing import Dict
 
 from nvflare.apis.fl_context import FLContext
 from nvflare.app_common.model_desc import ModelDescriptor
@@ -56,7 +57,7 @@ def save(self, learnable: ModelLearnable, fl_ctx: FLContext):
         if self.filter_id:
             _filter.process_post_save(learnable=learnable, fl_ctx=fl_ctx)
 
-    def get(self, model_file, fl_ctx: FLContext) -> object:
+    def get(self, model_file: str, fl_ctx: FLContext) -> object:
         learnable = self.get_model(model_file, fl_ctx)
 
         if self.filter_id:
@@ -90,13 +91,14 @@ def save_model(self, model: ModelLearnable, fl_ctx: FLContext):
         """
         pass
 
-    def get_model_inventory(self, fl_ctx: FLContext) -> {str: ModelDescriptor}:
-        """Get the model inventory of the ModelPersister.
+    def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
+        """Get the model inventory of the ModelPersistor.
 
         Args:
             fl_ctx: FLContext
 
-        Returns: { model_kind: ModelDescriptor }
+        Returns:
+            A dict of model_name: ModelDescriptor
         """
         pass
 
diff --git a/nvflare/app_common/workflows/base_model_controller.py b/nvflare/app_common/workflows/base_model_controller.py
index 4ff07564b1..f40ff6f9a9 100644
--- a/nvflare/app_common/workflows/base_model_controller.py
+++ b/nvflare/app_common/workflows/base_model_controller.py
@@ -83,7 +83,7 @@ def start_controller(self, fl_ctx: FLContext) -> None:
             self.persistor = self.engine.get_component(self._persistor_id)
             if not isinstance(self.persistor, LearnablePersistor):
                 self.warning(
-                    f"Model Persistor {self._persistor_id} must be a LearnablePersistor type object, "
+                    f"Persistor {self._persistor_id} must be a LearnablePersistor type object, "
                     f"but got {type(self.persistor)}"
                 )
                 self.persistor = None
diff --git a/nvflare/app_common/workflows/cross_site_eval.py b/nvflare/app_common/workflows/cross_site_eval.py
index 395b20e985..b6d0e3c5af 100644
--- a/nvflare/app_common/workflows/cross_site_eval.py
+++ b/nvflare/app_common/workflows/cross_site_eval.py
@@ -18,6 +18,7 @@
 import time
 
 from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.abstract.model_persistor import ModelPersistor
 from nvflare.app_common.app_constant import AppConstants, DefaultCheckpointFileName, ModelName
 from nvflare.app_common.utils.fl_model_utils import FLModelUtils
 from nvflare.fuel.utils import fobs
@@ -104,20 +105,25 @@ def run(self) -> None:
                 callback=self._receive_local_model_cb,
             )
 
+        if self.persistor and not isinstance(self.persistor, ModelPersistor):
+            self.warning(
+                f"Model Persistor {self._persistor_id} must be a ModelPersistor type object, "
+                f"but got {type(self.persistor)}"
+            )
+            self.persistor = None
+
         # Obtain server models and send to clients for validation
-        if self.persistor:
-            for server_model_name in self._server_models:
-                server_model_path = os.path.join(self.get_app_dir(), server_model_name)
-                server_model_learnable = self.persistor.get_model_from_location(server_model_path, self.fl_ctx)
-                server_model = FLModelUtils.from_model_learnable(server_model_learnable)
-                self._send_validation_task(server_model_name, server_model)
-        else:
-            for server_model_name in self._server_models:
-                try:
+        for server_model_name in self._server_models:
+            try:
+                if self.persistor:
+                    server_model_learnable = self.persistor.get_model(server_model_name, self.fl_ctx)
+                    server_model = FLModelUtils.from_model_learnable(server_model_learnable)
+                else:
                     server_model = fobs.loadf(server_model_name)
-                    self._send_validation_task(server_model_name, server_model)
-                except Exception as e:
-                    self.exception(f"Unable to load server model {server_model_name}: {e}")
+            except Exception as e:
+                self.exception(f"Unable to load server model {server_model_name}: {e}")
+                continue
+            self._send_validation_task(server_model_name, server_model)
 
         # Wait for all standing tasks to complete, since we used non-blocking `send_model()`
         while self.get_num_standing_tasks():
@@ -128,6 +134,7 @@ def run(self) -> None:
             time.sleep(self._task_check_period)
 
         self.save_results()
+        self.info("Stop Cross-Site Evaluation.")
 
     def _receive_local_model_cb(self, model: FLModel):
         client_name = model.meta["client_name"]
@@ -193,5 +200,6 @@ def save_results(self):
             os.makedirs(cross_val_res_dir)
 
         res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
+        self.info(f"saving validation result {self._json_val_results} to {res_file_path}")
         with open(res_file_path, "w") as f:
-            json.dump(self._json_val_results, f)
+            f.write(json.dumps(self._json_val_results, indent=2))
diff --git a/nvflare/app_opt/pt/file_model_persistor.py b/nvflare/app_opt/pt/file_model_persistor.py
index bcb3830263..249e117013 100644
--- a/nvflare/app_opt/pt/file_model_persistor.py
+++ b/nvflare/app_opt/pt/file_model_persistor.py
@@ -259,11 +259,10 @@ def get_model(self, model_file: str, fl_ctx: FLContext) -> ModelLearnable:
                 return None
 
             location = desc.location
-            return self.get_model_from_location(location, fl_ctx)
+            return self._get_model_from_location(location, fl_ctx)
 
-    def get_model_from_location(self, location, fl_ctx):
+    def _get_model_from_location(self, location, fl_ctx):
         try:
-            # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Use the "cpu" to load the global model weights, avoid GPU out of memory
             device = "cpu"
             data = torch.load(location, map_location=device)
@@ -271,7 +270,7 @@
             return persistence_manager.to_model_learnable(self.exclude_vars)
         except Exception:
             self.log_exception(fl_ctx, "error loading checkpoint from {}".format(location))
-            return {}
+            return None
 
     def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
         model_inventory = {}