Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance CLI command config #2716

Merged
merged 3 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions nvflare/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
create_poc_workspace_config,
create_startup_kit_config,
get_hidden_config,
print_hidden_config,
save_config,
)

Expand Down Expand Up @@ -116,14 +117,18 @@ def def_config_parser(sub_cmd):
def handle_config_cmd(args):
config_file_path, nvflare_config = get_hidden_config()

if not args.job_templates_dir or not os.path.isdir(args.job_templates_dir):
raise ValueError(f"job_templates_dir='{args.job_templates_dir}', it is not a directory")
if args.startup_kit_dir is None and args.poc_workspace_dir is None and args.job_templates_dir is None:
print(f"not specifying any directory. print existing config at {config_file_path}")
print_hidden_config(config_file_path, nvflare_config)
return

nvflare_config = create_startup_kit_config(nvflare_config, args.startup_kit_dir)
nvflare_config = create_poc_workspace_config(nvflare_config, args.poc_workspace_dir)
nvflare_config = create_job_template_config(nvflare_config, args.job_templates_dir)

save_config(nvflare_config, config_file_path)
print(f"new config at {config_file_path}")
print_hidden_config(config_file_path, nvflare_config)


def parse_args(prog_name: str):
Expand Down
25 changes: 7 additions & 18 deletions nvflare/tool/poc/poc_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import json
import os
import pathlib
import random
import shutil
import socket
Expand All @@ -36,7 +35,7 @@
from nvflare.lighter.utils import load_yaml, update_project_server_name_config, update_storage_locations
from nvflare.tool.api_utils import shutdown_system
from nvflare.tool.poc.service_constants import FlareServiceConstants as SC
from nvflare.utils.cli_utils import hocon_to_string
from nvflare.utils.cli_utils import get_hidden_nvflare_config_path, get_or_create_hidden_nvflare_dir, hocon_to_string

DEFAULT_WORKSPACE = "/tmp/nvflare/poc"
DEFAULT_PROJECT_NAME = "example_project"
Expand Down Expand Up @@ -396,7 +395,7 @@ def prepare_clients(clients, number_of_clients):


def save_startup_kit_dir_config(workspace, project_name):
dst = get_hidden_nvflare_config_path()
dst = get_or_create_hidden_nvflare_config_path()
config = None
if os.path.isfile(dst):
try:
Expand Down Expand Up @@ -485,27 +484,17 @@ def _prepare_poc(
return True


def get_home_dir():
return Path.home()


def get_hidden_nvflare_config_path() -> str:
def get_or_create_hidden_nvflare_config_path() -> str:
"""
Get the path for the hidden nvflare configuration file.

Returns:
str: The path to the hidden nvflare configuration file.
"""
home_dir = get_home_dir()
hidden_nvflare_dir = pathlib.Path(home_dir) / ".nvflare"

try:
hidden_nvflare_dir.mkdir(exist_ok=True)
except OSError as e:
raise RuntimeError(f"Error creating the hidden nvflare directory: {e}")
hidden_nvflare_dir = get_or_create_hidden_nvflare_dir()

hidden_nvflare_config_file = hidden_nvflare_dir / "config.conf"
return str(hidden_nvflare_config_file)
hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(hidden_nvflare_dir))
return hidden_nvflare_config_file


def prepare_poc_provision(
Expand Down Expand Up @@ -1077,7 +1066,7 @@ def get_poc_workspace():
poc_workspace = os.getenv("NVFLARE_POC_WORKSPACE")

if not poc_workspace:
src_path = get_hidden_nvflare_config_path()
src_path = get_or_create_hidden_nvflare_config_path()
if os.path.isfile(src_path):
from pyhocon import ConfigFactory as CF

Expand Down
18 changes: 13 additions & 5 deletions nvflare/utils/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_hidden_nvflare_config_path(hidden_nvflare_dir: str) -> str:
return str(hidden_nvflare_config_file)


def create_hidden_nvflare_dir():
def get_or_create_hidden_nvflare_dir():
hidden_nvflare_dir = get_hidden_nvflare_dir()
if not hidden_nvflare_dir.exists():
try:
Expand Down Expand Up @@ -70,7 +70,7 @@ def find_startup_kit_location() -> str:


def load_hidden_config() -> ConfigTree:
hidden_dir = create_hidden_nvflare_dir()
hidden_dir = get_or_create_hidden_nvflare_dir()
hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(hidden_dir))
nvflare_config = load_config(hidden_nvflare_config_file)
return nvflare_config
Expand Down Expand Up @@ -139,6 +139,7 @@ def create_job_template_config(nvflare_config: ConfigTree, job_templates_dir: Op
return nvflare_config

job_templates_dir = os.path.abspath(job_templates_dir)
check_dir(job_templates_dir)
conf_str = f"""
job_template {{
path = {job_templates_dir}
Expand Down Expand Up @@ -243,7 +244,7 @@ def save_configs(app_configs: Dict[str, Tuple], keep_origin_format: bool = True)
save_config(dst_config, dst_path, keep_origin_format)


def save_config(dst_config, dst_path, keep_origin_format: bool = True):
def save_config(dst_config: ConfigTree, dst_path, keep_origin_format: bool = True):
if dst_path is None or dst_path.rindex(".") == -1:
raise ValueError(f"configuration file path '{dst_path}' can't be None or has no extension")

Expand Down Expand Up @@ -274,13 +275,20 @@ def save_config(dst_config, dst_path, keep_origin_format: bool = True):
os.remove(dst_path)


def get_hidden_config():
hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(create_hidden_nvflare_dir()))
def get_hidden_config() -> (str, ConfigTree):
hidden_nvflare_config_file = get_hidden_nvflare_config_path(str(get_or_create_hidden_nvflare_dir()))
conf = load_hidden_config()
nvflare_config = CF.parse_string("{}") if not conf else conf
return hidden_nvflare_config_file, nvflare_config


def print_hidden_config(dst_path: str, dst_config: ConfigTree):
original_ext = os.path.basename(dst_path).split(".")[1]
fmt = ConfigFormat.config_ext_formats().get(f".{original_ext}", None)
config_str = hocon_to_string(fmt, dst_config)
print(config_str)


def find_in_list(arr: List, item) -> bool:
if arr is None:
return False
Expand Down
Loading