Skip to content

Commit

Permalink
Fix import issues (NVIDIA#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA authored Mar 31, 2022
1 parent a9dd30e commit 6416843
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 22 deletions.
4 changes: 2 additions & 2 deletions nvflare/apis/impl/study_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import datetime
import json
import tempfile
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

from nvflare.apis.study_manager_spec import Study, StudyManagerSpec
from nvflare.apis.storage import StorageSpec
from nvflare.apis.study_manager_spec import Study, StudyManagerSpec


def custom_json_encoder(obj):
Expand Down
2 changes: 1 addition & 1 deletion nvflare/apis/study_manager_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, List
from datetime import datetime
from typing import List

from .fl_context import FLContext

Expand Down
21 changes: 14 additions & 7 deletions nvflare/private/fed/server/fed_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import nvflare.private.fed.protos.federated_pb2 as fed_msg
import nvflare.private.fed.protos.federated_pb2_grpc as fed_service
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, MachineStatus, SnapshotKey, ServerCommandNames, ServerCommandKey
from nvflare.apis.fl_constant import FLContextKey, MachineStatus, ServerCommandKey, ServerCommandNames, SnapshotKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, ReturnCode, Shareable, make_reply
from nvflare.apis.workspace import Workspace
Expand All @@ -44,6 +44,7 @@
from nvflare.private.fed.utils.messageproto import message_to_proto, proto_to_message
from nvflare.private.fed.utils.numproto import proto_to_bytes
from nvflare.widgets.fed_event import ServerFedEventRunner

from .client_manager import ClientManager
from .run_manager import RunManager
from .server_engine import ServerEngine
Expand Down Expand Up @@ -438,8 +439,10 @@ def _process_task_request(self, client, fl_ctx, shared_fl_ctx):
command_shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)
command_shareable.set_header(ServerCommandKey.FL_CLIENT, client)

data = {ServerCommandKey.COMMAND: ServerCommandNames.GET_TASK,
ServerCommandKey.DATA: command_shareable}
data = {
ServerCommandKey.COMMAND: ServerCommandNames.GET_TASK,
ServerCommandKey.DATA: command_shareable,
}
self.engine.command_conn.send(data)

return_data = self.engine.command_conn.recv()
Expand Down Expand Up @@ -530,8 +533,10 @@ def _submit_update(self, client, contribution_task_name, shareable, shared_fl_co
command_shareable.set_header(ServerCommandKey.TASK_ID, task_id)
command_shareable.set_header(ServerCommandKey.SHAREABLE, shareable)

data = {ServerCommandKey.COMMAND: ServerCommandNames.SUBMIT_UPDATE,
ServerCommandKey.DATA: command_shareable}
data = {
ServerCommandKey.COMMAND: ServerCommandNames.SUBMIT_UPDATE,
ServerCommandKey.DATA: command_shareable,
}
self.engine.command_conn.send(data)
except BaseException:
self.logger.info("Could not connect to server runner process - asked client to end the run")
Expand Down Expand Up @@ -591,8 +596,10 @@ def _aux_communicate(self, fl_ctx, shareable, shared_fl_context, topic):
command_shareable.set_header(ServerCommandKey.TOPIC, topic)
command_shareable.set_header(ServerCommandKey.SHAREABLE, shareable)

data = {ServerCommandKey.COMMAND: ServerCommandNames.AUX_COMMUNICATE,
ServerCommandKey.DATA: command_shareable}
data = {
ServerCommandKey.COMMAND: ServerCommandNames.AUX_COMMUNICATE,
ServerCommandKey.DATA: command_shareable,
}
self.engine.command_conn.send(data)

return_data = self.engine.command_conn.recv()
Expand Down
1 change: 1 addition & 0 deletions nvflare/private/fed/server/server_command_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from multiprocessing.connection import Listener

from nvflare.apis.fl_constant import ServerCommandKey

from .server_commands import ServerCommands


Expand Down
8 changes: 4 additions & 4 deletions nvflare/private/fed/server/server_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import copy
import time

from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, ServerCommandNames, ServerCommandKey
from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, ServerCommandKey, ServerCommandNames
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.fl_context_utils import get_serializable_data
Expand Down Expand Up @@ -71,7 +71,7 @@ def process(self, data: Shareable, fl_ctx: FLContext):
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
server_runner.abort(fl_ctx)
# wait for the runner process gracefully abort the run.
time.sleep(3.)
time.sleep(3.0)
return "Aborted the run"


Expand Down Expand Up @@ -131,7 +131,7 @@ def process(self, data: Shareable, fl_ctx: FLContext):
ServerCommandKey.TASK_NAME: taskname,
ServerCommandKey.TASK_ID: task_id,
ServerCommandKey.SHAREABLE: shareable,
ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props)
ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props),
}
return data

Expand Down Expand Up @@ -202,7 +202,7 @@ def process(self, data: Shareable, fl_ctx: FLContext):

data = {
ServerCommandKey.AUX_REPLY: reply,
ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props)
ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props),
}
return data

Expand Down
33 changes: 25 additions & 8 deletions nvflare/private/fed/server/server_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,22 @@
from typing import List, Tuple

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import FLContextKey, MachineStatus, ReservedTopic, ReturnCode, SnapshotKey, \
AdminCommandNames, ServerCommandNames, ServerCommandKey
from nvflare.apis.fl_constant import (
AdminCommandNames,
FLContextKey,
MachineStatus,
ReservedTopic,
ReturnCode,
ServerCommandKey,
ServerCommandNames,
SnapshotKey,
)
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_snapshot import FLSnapshot
from nvflare.apis.impl.job_def_manager import SimpleJobDefManager
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.utils.common_utils import get_open_ports
from nvflare.apis.study_manager_spec import StudyManagerSpec
from nvflare.apis.utils.common_utils import get_open_ports
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.apis.workspace import Workspace
from nvflare.app_common.storages.filesystem_storage import FilesystemStorage
Expand All @@ -45,6 +53,7 @@
from nvflare.private.fed.server.server_json_config import ServerJsonConfigurator
from nvflare.widgets.info_collector import InfoCollector
from nvflare.widgets.widget import Widget, WidgetID

from .client_manager import ClientManager
from .run_manager import RunManager
from .server_engine_internal_spec import EngineInfo, RunInfo, ServerEngineInternalSpec
Expand Down Expand Up @@ -168,7 +177,9 @@ def start_app_on_server(self, snapshot=None) -> str:
app_custom_folder = ""
if self.server.enable_byoc:
app_custom_folder = os.path.join(app_root, "custom")
self.child_process = self._start_runner_process(self.args, app_root, self.run_number, app_custom_folder, snapshot)
self.child_process = self._start_runner_process(
self.args, app_root, self.run_number, app_custom_folder, snapshot
)

threading.Thread(target=self.wait_for_complete).start()

Expand Down Expand Up @@ -211,10 +222,16 @@ def _start_runner_process(self, args, app_root, run_number, app_custom_folder, s
command = (
f"{sys.executable} -m nvflare.private.fed.app.server.runner_process -m "
+ args.workspace
+ " -s fed_server.json -r " + app_root
+ " -n " + str(run_number)
+ " -p " + str(self.open_port)
+ " --set" + command_options + " print_conf=True restore_snapshot=" + str(restore_snapshot)
+ " -s fed_server.json -r "
+ app_root
+ " -n "
+ str(run_number)
+ " -p "
+ str(self.open_port)
+ " --set"
+ command_options
+ " print_conf=True restore_snapshot="
+ str(restore_snapshot)
)
# use os.setsid to create new process group ID

Expand Down

0 comments on commit 6416843

Please sign in to comment.