Skip to content

Commit

Permalink
Add CLI subcommand union and alias support (#380)
Browse files Browse the repository at this point in the history
Co-authored-by: Kyle Schwab <kschwab@micron.com>
  • Loading branch information
kschwab and Kyle Schwab committed Sep 9, 2024
1 parent 12d85cf commit a9eb22e
Show file tree
Hide file tree
Showing 3 changed files with 413 additions and 80 deletions.
68 changes: 64 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ not required, set the `is_required` flag to `False` to disable raising an error
subcommands](https://docs.python.org/3/library/argparse.html#sub-commands).

!!! note
`CliSubCommand` and `CliPositionalArg` are always case sensitive and do not support aliases.
`CliSubCommand` and `CliPositionalArg` are always case sensitive.

```py
import sys
Expand Down Expand Up @@ -817,6 +817,64 @@ assert get_subcommand(cmd).model_dump() == {
}
```

The `CliSubCommand` and `CliPositionalArg` annotations also support union operations and aliases. For unions of Pydantic
models, it is important to remember the [nuances](https://docs.pydantic.dev/latest/concepts/unions/) that can arise
during validation. Specifically, for unions of subcommands that are identical in content, it is recommended to break
them out into separate `CliSubCommand` fields to avoid any complications. Lastly, the derived subcommand names from
unions will be the names of the Pydantic model classes themselves.

When assigning aliases to `CliSubCommand` or `CliPositionalArg` fields, only a single alias can be assigned. For
non-union subcommands, aliasing will change the displayed help text and subcommand name. Conversely, for union
subcommands, aliasing will have no tangible effect from the perspective of the CLI settings source. Lastly, for
positional arguments, aliasing will change the CLI help text displayed for the field.

```py
import sys
from typing import Union

from pydantic import BaseModel, Field

from pydantic_settings import (
BaseSettings,
CliPositionalArg,
CliSubCommand,
get_subcommand,
)


class Alpha(BaseModel):
"""Apha Help"""

cmd_alpha: CliPositionalArg[str] = Field(alias='alpha-cmd')


class Beta(BaseModel):
"""Beta Help"""

opt_beta: str = Field(alias='opt-beta')


class Gamma(BaseModel):
"""Gamma Help"""

opt_gamma: str = Field(alias='opt-gamma')


class Root(BaseSettings, cli_parse_args=True, cli_exit_on_error=False):
alpha_or_beta: CliSubCommand[Union[Alpha, Beta]] = Field(alias='alpha-or-beta-cmd')
gamma: CliSubCommand[Gamma] = Field(alias='gamma-cmd')


sys.argv = ['example.py', 'Alpha', 'hello']
assert get_subcommand(Root()).model_dump() == {'cmd_alpha': 'hello'}

sys.argv = ['example.py', 'Beta', '--opt-beta=hey']
assert get_subcommand(Root()).model_dump() == {'opt_beta': 'hey'}

sys.argv = ['example.py', 'gamma-cmd', '--opt-gamma=hi']
assert get_subcommand(Root()).model_dump() == {'opt_gamma': 'hi'}
```

### Customizing the CLI Experience

The below flags can be used to customise the CLI experience to your needs.
Expand Down Expand Up @@ -861,9 +919,11 @@ Additionally, the provided `CliImplicitFlag` and `CliExplicitFlag` annotations c
when necessary.

!!! note
For `python < 3.9`:
* The `--no-flag` option is not generated due to an underlying `argparse` limitation.
* The `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional bool fields.
For `python < 3.9` the `--no-flag` option is not generated due to an underlying `argparse` limitation.

!!! note
For `python < 3.9` the `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional boolean
fields.

```py
from pydantic_settings import BaseSettings, CliExplicitFlag, CliImplicitFlag
Expand Down
168 changes: 99 additions & 69 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if sys.version_info >= (3, 9):
from argparse import BooleanOptionalAction
from argparse import SUPPRESS, ArgumentParser, Namespace, RawDescriptionHelpFormatter, _SubParsersAction
from collections import deque
from collections import defaultdict, deque
from dataclasses import asdict, is_dataclass
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -1239,12 +1239,14 @@ def _load_env_vars(
if isinstance(val, list):
parsed_args[field_name] = self._merge_parsed_list(val, field_name)
elif field_name.endswith(':subcommand') and val is not None:
selected_subcommands.append(field_name.split(':')[0] + val)
subcommand_name = field_name.split(':')[0] + val
subcommand_dest = self._cli_subcommands[field_name][subcommand_name]
selected_subcommands.append(subcommand_dest)

for subcommands in self._cli_subcommands.values():
for subcommand in subcommands:
if subcommand not in selected_subcommands:
parsed_args[subcommand] = self.cli_parse_none_str
for subcommand_dest in subcommands.values():
if subcommand_dest not in selected_subcommands:
parsed_args[subcommand_dest] = self.cli_parse_none_str

parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')}
if selected_subcommands:
Expand Down Expand Up @@ -1389,26 +1391,26 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
sub_models.append(type_) # type: ignore
return sub_models

def _get_resolved_names(
def _get_alias_names(
self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str]
) -> tuple[tuple[str, ...], bool]:
resolved_names: list[str] = []
alias_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
resolved_names += [field_name]
alias_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
resolved_names.append(alias)
alias_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
resolved_names.append(name)
alias_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
Expand All @@ -1418,11 +1420,11 @@ def _get_resolved_names(
name = cast(str, alias_path.path[0])
name = name.lower() if not self.case_sensitive else name
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
if not resolved_names and is_alias_path_only:
resolved_names.append(name)
if not alias_names and is_alias_path_only:
alias_names.append(name)
if not self.case_sensitive:
resolved_names = [resolved_name.lower() for resolved_name in resolved_names]
return tuple(dict.fromkeys(resolved_names)), is_alias_path_only
alias_names = [alias_name.lower() for alias_name in alias_names]
return tuple(dict.fromkeys(alias_names)), is_alias_path_only

def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
if _CliImplicitFlag in field_info.metadata:
Expand All @@ -1447,22 +1449,24 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
if _CliSubCommand in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
elif any((field_info.alias, field_info.validation_alias)):
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has an alias')
else:
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
if len(alias_names) > 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
if len(field_types) != 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple types')
elif not (is_model_class(field_types[0]) or is_pydantic_dataclass(field_types[0])):
raise SettingsError(
f'subcommand argument {model.__name__}.{field_name} is not derived from BaseModel'
)
for field_type in field_types:
if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)):
raise SettingsError(
f'subcommand argument {model.__name__}.{field_name} has type not derived from BaseModel'
)
subcommand_args.append((field_name, field_info))
elif _CliPositionalArg in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
elif any((field_info.alias, field_info.validation_alias)):
raise SettingsError(f'positional argument {model.__name__}.{field_name} has an alias')
else:
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
Expand Down Expand Up @@ -1529,7 +1533,7 @@ def _connect_root_parser(
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
self._formatter_class = formatter_class
self._cli_dict_args: dict[str, type[Any] | None] = {}
self._cli_subcommands: dict[str, list[str]] = {}
self._cli_subcommands: defaultdict[str, dict[str, str]] = defaultdict(dict)
self._add_parser_args(
parser=self.root_parser,
model=self.settings_cls,
Expand All @@ -1556,64 +1560,93 @@ def _add_parser_args(
alias_path_args: dict[str, str] = {}
for field_name, field_info in self._sort_arg_fields(model):
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args)
preferred_alias = alias_names[0]
if _CliSubCommand in field_info.metadata:
if subparsers is None:
subparsers = self._add_subparsers(parser, title='subcommands', dest=f'{arg_prefix}:subcommand')
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}']
else:
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}')
if hasattr(subparsers, 'metavar'):
metavar = ','.join(self._cli_subcommands[f'{arg_prefix}:subcommand'])
subparsers.metavar = f'{{{metavar}}}'

model = sub_models[0]
self._add_parser_args(
parser=self._add_parser(
subparsers,
field_name,
help=field_info.description,
formatter_class=self._formatter_class,
description=None if model.__doc__ is None else dedent(model.__doc__),
),
model=model,
added_args=[],
arg_prefix=f'{arg_prefix}{field_name}.',
subcommand_prefix=f'{subcommand_prefix}{field_name}.',
group=None,
alias_prefixes=[],
model_default=PydanticUndefined,
)
for model in sub_models:
subcommand_alias = model.__name__ if len(sub_models) > 1 else preferred_alias
subcommand_name = f'{arg_prefix}{subcommand_alias}'
subcommand_dest = f'{arg_prefix}{preferred_alias}'
self._cli_subcommands[f'{arg_prefix}:subcommand'][subcommand_name] = subcommand_dest

subcommand_help = None if len(sub_models) > 1 else field_info.description
if self.cli_use_class_docs_for_groups:
subcommand_help = None if model.__doc__ is None else dedent(model.__doc__)

subparsers = (
self._add_subparsers(
parser,
title='subcommands',
dest=f'{arg_prefix}:subcommand',
description=field_info.description if len(sub_models) > 1 else None,
)
if subparsers is None
else subparsers
)

if hasattr(subparsers, 'metavar'):
subparsers.metavar = (
f'{subparsers.metavar[:-1]},{subcommand_alias}}}'
if subparsers.metavar
else f'{{{subcommand_alias}}}'
)

self._add_parser_args(
parser=self._add_parser(
subparsers,
subcommand_alias,
help=subcommand_help,
formatter_class=self._formatter_class,
description=None if model.__doc__ is None else dedent(model.__doc__),
),
model=model,
added_args=[],
arg_prefix=f'{arg_prefix}{preferred_alias}.',
subcommand_prefix=f'{subcommand_prefix}{preferred_alias}.',
group=None,
alias_prefixes=[],
model_default=PydanticUndefined,
)
else:
resolved_names, is_alias_path_only = self._get_resolved_names(field_name, field_info, alias_path_args)
arg_flag: str = '--'
is_append_action = _annotation_contains_types(
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
)
is_parser_submodel = sub_models and not is_append_action
kwargs: dict[str, Any] = {}
kwargs['default'] = SUPPRESS
kwargs['help'] = self._help_format(field_name, field_info, model_default)
kwargs['dest'] = f'{arg_prefix}{resolved_names[0]}'
kwargs['metavar'] = self._metavar_format(field_info.annotation)
kwargs['required'] = (
self.cli_enforce_required and field_info.is_required() and model_default is PydanticUndefined
)
kwargs['dest'] = (
# Strip prefix if validation alias is set and value is not complex.
# Related https://github.com/pydantic/pydantic-settings/pull/25
f'{arg_prefix}{preferred_alias}'[self.env_prefix_len :]
if arg_prefix and field_info.validation_alias is not None and not is_parser_submodel
else f'{arg_prefix}{preferred_alias}'
)

if kwargs['dest'] in added_args:
continue
if _annotation_contains_types(
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
):

if is_append_action:
kwargs['action'] = 'append'
if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_strip_annotated=True):
self._cli_dict_args[kwargs['dest']] = field_info.annotation

arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, resolved_names)
arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, alias_names)
if _CliPositionalArg in field_info.metadata:
kwargs['metavar'] = resolved_names[0].upper()
kwargs['metavar'] = preferred_alias.upper()
arg_names = [kwargs['dest']]
del kwargs['dest']
del kwargs['required']
arg_flag = ''

self._convert_bool_flag(kwargs, field_info, model_default)

if sub_models and kwargs.get('action') != 'append':
if is_parser_submodel:
self._add_parser_submodels(
parser,
sub_models,
Expand All @@ -1625,14 +1658,10 @@ def _add_parser_args(
kwargs,
field_name,
field_info,
resolved_names,
alias_names,
model_default=model_default,
)
elif not is_alias_path_only:
if arg_prefix and field_info.validation_alias is not None:
# Strip prefix if validation alias is set and value is not complex.
# Related https://github.com/pydantic/pydantic-settings/pull/25
kwargs['dest'] = kwargs['dest'][self.env_prefix_len :]
if group is not None:
if isinstance(group, dict):
group = self._add_argument_group(parser, **group)
Expand Down Expand Up @@ -1662,11 +1691,11 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
)

def _get_arg_names(
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], resolved_names: tuple[str, ...]
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], alias_names: tuple[str, ...]
) -> list[str]:
arg_names: list[str] = []
for prefix in [arg_prefix] + alias_prefixes:
for name in resolved_names:
for name in alias_names:
arg_names.append(
f'{prefix}{name}'
if subcommand_prefix == self.env_prefix
Expand All @@ -1686,7 +1715,7 @@ def _add_parser_submodels(
kwargs: dict[str, Any],
field_name: str,
field_info: FieldInfo,
resolved_names: tuple[str, ...],
alias_names: tuple[str, ...],
model_default: Any,
) -> None:
model_group: Any = None
Expand All @@ -1711,6 +1740,7 @@ def _add_parser_submodels(
else:
model_group_kwargs['description'] = desc_header

preferred_alias = alias_names[0]
if not self.cli_avoid_json:
added_args.append(arg_names[0])
kwargs['help'] = f'set {arg_names[0]} from JSON string'
Expand All @@ -1721,10 +1751,10 @@ def _add_parser_submodels(
parser=parser,
model=model,
added_args=added_args,
arg_prefix=f'{arg_prefix}{resolved_names[0]}.',
arg_prefix=f'{arg_prefix}{preferred_alias}.',
subcommand_prefix=subcommand_prefix,
group=model_group if model_group else model_group_kwargs,
alias_prefixes=[f'{arg_prefix}{name}.' for name in resolved_names[1:]],
alias_prefixes=[f'{arg_prefix}{name}.' for name in alias_names[1:]],
model_default=model_default,
)

Expand Down
Loading

0 comments on commit a9eb22e

Please sign in to comment.