Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable simulator to run HE #2339

Merged
merged 11 commits into from
Apr 10, 2024
4 changes: 3 additions & 1 deletion nvflare/private/fed/app/simulator/simulator_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_bytes, zip_directory_to_bytes
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer
from nvflare.private.fed.app.utils import kill_child_processes
from nvflare.private.fed.app.utils import init_security_content_service, kill_child_processes
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.server.job_meta_validator import JobMetaValidator
from nvflare.private.fed.simulator.simulator_app_runner import SimulatorServerAppRunner
Expand Down Expand Up @@ -153,6 +153,8 @@ def setup(self):
AuthorizationService.initialize(EmptyAuthorizer())
AuditService.the_auditor = SimulatorAuditor()

init_security_content_service(self.args.workspace)

self.simulator_root = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME)
if os.path.exists(self.simulator_root):
shutil.rmtree(self.simulator_root)
Expand Down
4 changes: 3 additions & 1 deletion nvflare/private/fed/app/simulator/simulator_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from nvflare.fuel.hci.server.authz import AuthorizationService
from nvflare.fuel.sec.audit import AuditService
from nvflare.private.fed.app.deployer.base_client_deployer import BaseClientDeployer
from nvflare.private.fed.app.utils import check_parent_alive
from nvflare.private.fed.app.utils import check_parent_alive, init_security_content_service
from nvflare.private.fed.client.client_engine import ClientEngine
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.client.fed_client import FederatedClient
Expand Down Expand Up @@ -241,6 +241,8 @@ def main(args):
# AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG)
AuditService.the_auditor = SimulatorAuditor()

init_security_content_service(args.workspace)

if args.gpu:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
Expand Down
10 changes: 9 additions & 1 deletion nvflare/private/fed/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import psutil

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_constant import FLContextKey, WorkspaceConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import UnsafeComponentError
from nvflare.apis.workspace import Workspace
from nvflare.fuel.hci.security import hash_password
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.private.defs import SSLConstants
from nvflare.private.fed.runner import Runner
from nvflare.private.fed.server.admin import FedAdminServer
Expand Down Expand Up @@ -103,6 +105,12 @@ def version_check():
raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10")


def init_security_content_service(workspace_dir):
os.makedirs(os.path.join(workspace_dir, WorkspaceConstants.STARTUP_FOLDER_NAME), exist_ok=True)
workspace_obj = Workspace(root_dir=workspace_dir)
SecurityContentService.initialize(content_folder=workspace_obj.get_startup_kit_dir())


def component_security_check(fl_ctx: FLContext):
exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS)
if exceptions:
Expand Down
20 changes: 14 additions & 6 deletions tests/unit_test/private/fed/app/simulator/simulator_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import os
import shutil
import uuid
from unittest.mock import patch

import pytest

from nvflare.apis.fl_constant import WorkspaceConstants
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner
from nvflare.private.fed.utils.fed_utils import split_gpus

Expand All @@ -28,14 +30,22 @@ def get_root_url_for_child(self):


class TestSimulatorRunner:
def setup_method(self, method):
self.workspace_name = str(uuid.uuid4())
self.cwd = os.getcwd()
os.makedirs(os.path.join(self.cwd, self.workspace_name, WorkspaceConstants.STARTUP_FOLDER_NAME))

def teardown_method(self, method):
os.chdir(self.cwd)
shutil.rmtree(os.path.join(self.cwd, self.workspace_name))

@patch("nvflare.private.fed.app.deployer.simulator_deployer.SimulatorServer.deploy")
@patch("nvflare.private.fed.app.utils.FedAdminServer")
@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
@patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell())
def test_valid_job_simulate_setup(self, mock_deploy, mock_admin, mock_register, mock_cell):
workspace_name = str(uuid.uuid4())
job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job")
runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, threads=1)
runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, threads=1)
assert runner.setup()

expected_clients = ["site-1", "site-2"]
Expand All @@ -49,9 +59,8 @@ def test_valid_job_simulate_setup(self, mock_deploy, mock_admin, mock_register,
@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
@patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell())
def test_client_names_setup(self, mock_deploy, mock_admin, mock_register, mock_cell):
workspace_name = str(uuid.uuid4())
job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job")
runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, clients="site-1", threads=1)
runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, clients="site-1", threads=1)
assert runner.setup()

expected_clients = ["site-1"]
Expand All @@ -65,9 +74,8 @@ def test_client_names_setup(self, mock_deploy, mock_admin, mock_register, mock_c
@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
@patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell())
def test_no_app_for_client(self, mock_deploy, mock_admin, mock_register, mock_cell):
workspace_name = str(uuid.uuid4())
job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job")
runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, n_clients=3, threads=1)
runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, n_clients=3, threads=1)
assert not runner.setup()

@pytest.mark.parametrize(
Expand Down
Loading