Update CrossSiteEval (#2886)
* Update CrossSiteEval

* Update base class

* Undo no-need change

---------

Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
YuanTingHsieh and chesterxgchen authored Aug 30, 2024
1 parent fc088c1 commit d708465
Showing 5 changed files with 30 additions and 22 deletions.
examples/hello-world/step-by-step/cifar10/cse/cse.ipynb (1 addition, 1 deletion)

@@ -128,7 +128,7 @@
 "from nvflare.app_common.workflows.fedavg import FedAvg\n",
 "from nvflare.app_common.workflows.cross_site_eval import CrossSiteEval\n",
 "from nvflare.app_opt.pt.job_config.model import PTModel\n",
-"from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n",
+"from nvflare.job_config.script_runner import ScriptRunner\n",
 "\n",
 "\n",
 "if __name__ == \"__main__\":\n",
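Note: the FrameworkType import can be dropped because ScriptRunner defaults to the PyTorch framework. A minimal sketch of the simplified client setup, assuming the FedJob API used elsewhere in this example (the job name and script path below are illustrative):

    from nvflare.job_config.api import FedJob
    from nvflare.job_config.script_runner import ScriptRunner

    # Illustrative job setup; "cifar10_cse" and the script path are assumptions
    job = FedJob(name="cifar10_cse")
    for i in range(2):
        # framework defaults to PyTorch, so FrameworkType.PYTORCH is no longer passed
        runner = ScriptRunner(script="src/cifar10_fl.py")
        job.to(runner, f"site-{i + 1}")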
nvflare/app_common/abstract/model_persistor.py (5 additions, 3 deletions)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from typing import Dict
 
 from nvflare.apis.fl_context import FLContext
 from nvflare.app_common.model_desc import ModelDescriptor
@@ -56,7 +57,7 @@ def save(self, learnable: ModelLearnable, fl_ctx: FLContext):
         if self.filter_id:
             _filter.process_post_save(learnable=learnable, fl_ctx=fl_ctx)
 
-    def get(self, model_file, fl_ctx: FLContext) -> object:
+    def get(self, model_file: str, fl_ctx: FLContext) -> object:
         learnable = self.get_model(model_file, fl_ctx)
 
         if self.filter_id:
@@ -90,13 +91,14 @@ def save_model(self, model: ModelLearnable, fl_ctx: FLContext):
         """
         pass
 
-    def get_model_inventory(self, fl_ctx: FLContext) -> {str: ModelDescriptor}:
+    def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
         """Get the model inventory of the ModelPersister.
 
         Args:
             fl_ctx: FLContext
 
-        Returns: { model_kind: ModelDescriptor }
+        Returns:
+            A dict of model_name: ModelDescriptor
         """
         pass
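Note: the old annotation `{str: ModelDescriptor}` is a dict literal, not a valid type hint; `Dict[str, ModelDescriptor]` is what type checkers expect. A minimal sketch of a subclass honoring the corrected signature (the in-memory storage is an assumption for illustration, not an NVFlare class):

    from typing import Dict

    from nvflare.apis.fl_context import FLContext
    from nvflare.app_common.model_desc import ModelDescriptor


    class InMemoryPersistor:
        """Hypothetical persistor, used only to illustrate the corrected annotation."""

        def __init__(self):
            self._inventory: Dict[str, ModelDescriptor] = {}

        def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
            # Return a copy keyed by model name, matching the documented shape
            return dict(self._inventory)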
nvflare/app_common/workflows/base_model_controller.py (1 addition, 1 deletion)

@@ -83,7 +83,7 @@ def start_controller(self, fl_ctx: FLContext) -> None:
         self.persistor = self.engine.get_component(self._persistor_id)
         if not isinstance(self.persistor, LearnablePersistor):
             self.warning(
-                f"Model Persistor {self._persistor_id} must be a LearnablePersistor type object, "
+                f"Persistor {self._persistor_id} must be a LearnablePersistor type object, "
                 f"but got {type(self.persistor)}"
             )
             self.persistor = None
nvflare/app_common/workflows/cross_site_eval.py (20 additions, 13 deletions)

@@ -18,6 +18,7 @@
 import time
 
 from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.abstract.model_persistor import ModelPersistor
 from nvflare.app_common.app_constant import AppConstants, DefaultCheckpointFileName, ModelName
 from nvflare.app_common.utils.fl_model_utils import FLModelUtils
 from nvflare.fuel.utils import fobs
@@ -104,20 +105,24 @@ def run(self) -> None:
                 callback=self._receive_local_model_cb,
             )
 
+        if self.persistor and not isinstance(self.persistor, ModelPersistor):
+            self.warning(
+                f"Model Persistor {self._persistor_id} must be a ModelPersistor type object, "
+                f"but got {type(self.persistor)}"
+            )
+            self.persistor = None
+
         # Obtain server models and send to clients for validation
-        if self.persistor:
-            for server_model_name in self._server_models:
-                server_model_path = os.path.join(self.get_app_dir(), server_model_name)
-                server_model_learnable = self.persistor.get_model_from_location(server_model_path, self.fl_ctx)
-                server_model = FLModelUtils.from_model_learnable(server_model_learnable)
-                self._send_validation_task(server_model_name, server_model)
-        else:
-            for server_model_name in self._server_models:
-                try:
+        for server_model_name in self._server_models:
+            try:
+                if self.persistor:
+                    server_model_learnable = self.persistor.get_model(server_model_name, self.fl_ctx)
+                    server_model = FLModelUtils.from_model_learnable(server_model_learnable)
+                else:
                     server_model = fobs.loadf(server_model_name)
-                except Exception as e:
-                    self.exception(f"Unable to load server model {server_model_name}: {e}")
-                self._send_validation_task(server_model_name, server_model)
+                self._send_validation_task(server_model_name, server_model)
+            except Exception as e:
+                self.exception(f"Unable to load server model {server_model_name}: {e}")
 
         # Wait for all standing tasks to complete, since we used non-blocking `send_model()`
         while self.get_num_standing_tasks():
@@ -128,6 +133,7 @@ def run(self) -> None:
             time.sleep(self._task_check_period)
 
         self.save_results()
+        self.info("Stop Cross-Site Evaluation.")
 
     def _receive_local_model_cb(self, model: FLModel):
         client_name = model.meta["client_name"]
@@ -193,5 +199,6 @@ def save_results(self):
             os.makedirs(cross_val_res_dir)
 
         res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
+        self.info(f"saving validation result {self._json_val_results} to {res_file_path}")
         with open(res_file_path, "w") as f:
-            json.dump(self._json_val_results, f)
+            f.write(json.dumps(self._json_val_results, indent=2))
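Note: the results file is now written with `indent=2`, so it is human-readable. A hedged sketch of consuming it after a run; both the path and the `{data_client: {model_name: metrics}}` shape are assumptions based on the workflow's cross-validation conventions:

    import json

    # Assumed location under the server run directory; adjust to your deployment
    res_file_path = "cross_site_val/cross_val_results.json"

    with open(res_file_path) as f:
        results = json.load(f)

    # Illustrative shape: each data client maps evaluated model names to metrics
    for data_client, model_results in results.items():
        for model_name, metrics in model_results.items():
            print(f"{data_client} evaluated {model_name}: {metrics}")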
nvflare/app_opt/pt/file_model_persistor.py (3 additions, 4 deletions)

@@ -259,19 +259,18 @@ def get_model(self, model_file: str, fl_ctx: FLContext) -> ModelLearnable:
             return None
 
         location = desc.location
-        return self.get_model_from_location(location, fl_ctx)
+        return self._get_model_from_location(location, fl_ctx)
 
-    def get_model_from_location(self, location, fl_ctx):
+    def _get_model_from_location(self, location, fl_ctx):
         try:
-            # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Use the "cpu" to load the global model weights, avoid GPU out of memory
             device = "cpu"
             data = torch.load(location, map_location=device)
             persistence_manager = PTModelPersistenceFormatManager(data, default_train_conf=self.default_train_conf)
             return persistence_manager.to_model_learnable(self.exclude_vars)
         except Exception:
             self.log_exception(fl_ctx, "error loading checkpoint from {}".format(location))
-            return {}
+            return None
 
     def get_model_inventory(self, fl_ctx: FLContext) -> Dict[str, ModelDescriptor]:
         model_inventory = {}
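Note: with the loader renamed to the private `_get_model_from_location`, external callers go through `get_model`, and a failed load now surfaces as `None` instead of an empty dict. A minimal usage sketch, assuming `persistor` is a configured PTFileModelPersistor and `fl_ctx` a valid FLContext (the model file name is illustrative):

    # Assumes a configured PTFileModelPersistor and an FLContext from the engine
    learnable = persistor.get_model("FL_global_model.pt", fl_ctx)
    if learnable is None:
        # _get_model_from_location now returns None on any load error,
        # so callers can distinguish "no model" from an empty learnable
        print("server model could not be loaded")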
