Skip to content

Commit

Permalink
Add support for CliMutuallyExclusiveGroup. (#473)
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab authored Nov 11, 2024
1 parent 11a817c commit b4ece52
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 19 deletions.
38 changes: 38 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,44 @@ For `BaseModel` and `pydantic.dataclasses.dataclass` types, `CliApp.run` will in
The alias generator for kebab case does not propagate to subcommands or submodels and will have to be manually set
in these cases.

### Mutually Exclusive Groups

CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.

!!! note
A `CliMutuallyExclusiveGroup` cannot be used in a union or contain nested models.

```py
from typing import Optional

from pydantic import BaseModel

from pydantic_settings import CliApp, CliMutuallyExclusiveGroup, SettingsError


class Circle(CliMutuallyExclusiveGroup):
radius: Optional[float] = None
diameter: Optional[float] = None
perimeter: Optional[float] = None


class Settings(BaseModel):
circle: Circle


try:
CliApp.run(
Settings,
cli_args=['--circle.radius=1', '--circle.diameter=2'],
cli_exit_on_error=False,
)
except SettingsError as e:
print(e)
"""
error parsing CLI: argument --circle.diameter: not allowed with argument --circle.radius
"""
```

### Customizing the CLI Experience

The below flags can be used to customise the CLI experience to your needs.
Expand Down
2 changes: 2 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AzureKeyVaultSettingsSource,
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
Expand Down Expand Up @@ -34,6 +35,7 @@
'CliPositionalArg',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'InitSettingsSource',
'JsonConfigSettingsSource',
'PyprojectTomlConfigSettingsSource',
Expand Down
51 changes: 45 additions & 6 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def error(self, message: str) -> NoReturn:
super().error(message)


class CliMutuallyExclusiveGroup(BaseModel):
pass


T = TypeVar('T')
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
CliPositionalArg = Annotated[T, _CliPositionalArg]
Expand Down Expand Up @@ -1483,7 +1487,7 @@ def _connect_parser_method(
if (
parser_method is not None
and self.case_sensitive is False
and method_name == 'parsed_args_method'
and method_name == 'parse_args_method'
and isinstance(self._root_parser, _CliInternalArgParser)
):

Expand Down Expand Up @@ -1515,6 +1519,26 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any:
else:
return parser_method

def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]:
add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')

def add_group_method(parser: Any, **kwargs: Any) -> Any:
if not kwargs.pop('_is_cli_mutually_exclusive_group'):
kwargs.pop('required')
return add_argument_group(parser, **kwargs)
else:
main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
main_group_kwargs['title'] += ' (mutually exclusive)'
group = add_argument_group(parser, **main_group_kwargs)
if not hasattr(group, 'add_mutually_exclusive_group'):
raise SettingsError(
'cannot connect CLI settings source root parser: '
'group object is missing add_mutually_exclusive_group but is needed for connecting'
)
return group.add_mutually_exclusive_group(**kwargs)

return add_group_method

def _connect_root_parser(
self,
root_parser: T,
Expand All @@ -1531,9 +1555,9 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
self._root_parser = root_parser
if parse_args_method is None:
parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args
self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method')
self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method')
self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method')
self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
self._add_group = self._connect_group_method(add_argument_group_method)
self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method')
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
self._formatter_class = formatter_class
Expand Down Expand Up @@ -1665,6 +1689,7 @@ def _add_parser_args(
if is_parser_submodel:
self._add_parser_submodels(
parser,
model,
sub_models,
added_args,
arg_prefix,
Expand All @@ -1680,7 +1705,7 @@ def _add_parser_args(
elif not is_alias_path_only:
if group is not None:
if isinstance(group, dict):
group = self._add_argument_group(parser, **group)
group = self._add_group(parser, **group)
added_args += list(arg_names)
self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs)
else:
Expand Down Expand Up @@ -1724,6 +1749,7 @@ def _get_arg_names(
def _add_parser_submodels(
self,
parser: Any,
model: type[BaseModel],
sub_models: list[type[BaseModel]],
added_args: list[str],
arg_prefix: str,
Expand All @@ -1736,10 +1762,23 @@ def _add_parser_submodels(
alias_names: tuple[str, ...],
model_default: Any,
) -> None:
if issubclass(model, CliMutuallyExclusiveGroup):
# Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a
# mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion).
# Since nested models result in a group add, raise an exception for nested models in a mutually
# exclusive group.
raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup')

model_group: Any = None
model_group_kwargs: dict[str, Any] = {}
model_group_kwargs['title'] = f'{arg_names[0]} options'
model_group_kwargs['description'] = field_info.description
model_group_kwargs['required'] = kwargs['required']
model_group_kwargs['_is_cli_mutually_exclusive_group'] = any(
issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models
)
if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1:
raise SettingsError('cannot use union with CliMutuallyExclusiveGroup')
if self.cli_use_class_docs_for_groups and len(sub_models) == 1:
model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__)

Expand All @@ -1762,7 +1801,7 @@ def _add_parser_submodels(
if not self.cli_avoid_json:
added_args.append(arg_names[0])
kwargs['help'] = f'set {arg_names[0]} from JSON string'
model_group = self._add_argument_group(parser, **model_group_kwargs)
model_group = self._add_group(parser, **model_group_kwargs)
self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs)
for model in sub_models:
self._add_parser_args(
Expand All @@ -1788,7 +1827,7 @@ def _add_parser_alias_paths(
if alias_path_args:
context = parser
if group is not None:
context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group
context = self._add_group(parser, **group) if isinstance(group, dict) else group
is_nested_alias_path = arg_prefix.endswith('.')
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
for name, metavar in alias_path_args.items():
Expand Down
Loading

0 comments on commit b4ece52

Please sign in to comment.