diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index c5748fd19..922900170 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3372,11 +3372,15 @@ def arguments_parameter( return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias) +VarKwargsMode: TypeAlias = Literal['uniform', 'unpacked-typed-dict'] + + class ArgumentsSchema(TypedDict, total=False): type: Required[Literal['arguments']] arguments_schema: Required[List[ArgumentsParameter]] populate_by_name: bool var_args_schema: CoreSchema + var_kwargs_mode: VarKwargsMode var_kwargs_schema: CoreSchema ref: str metadata: Dict[str, Any] @@ -3388,6 +3392,7 @@ def arguments_schema( *, populate_by_name: bool | None = None, var_args_schema: CoreSchema | None = None, + var_kwargs_mode: VarKwargsMode | None = None, var_kwargs_schema: CoreSchema | None = None, ref: str | None = None, metadata: Dict[str, Any] | None = None, @@ -3414,6 +3419,9 @@ def arguments_schema( arguments: The arguments to use for the arguments schema populate_by_name: Whether to populate by name var_args_schema: The variable args schema to use for the arguments schema + var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'uniform'`, every value of the + keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`, + the `var_kwargs_schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema] var_kwargs_schema: The variable kwargs schema to use for the arguments schema ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3424,6 +3432,7 @@ def arguments_schema( arguments_schema=arguments, populate_by_name=populate_by_name, var_args_schema=var_args_schema, + var_kwargs_mode=var_kwargs_mode, var_kwargs_schema=var_kwargs_schema, ref=ref, metadata=metadata, diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 075bc009e..22f870c56 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple}; @@ -15,6 +17,27 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +#[derive(Debug, PartialEq)] +enum VarKwargsMode { + Uniform, + UnpackedTypedDict, +} + +impl FromStr for VarKwargsMode { + type Err = PyErr; + + fn from_str(s: &str) -> Result { + match s { + "uniform" => Ok(Self::Uniform), + "unpacked-typed-dict" => Ok(Self::UnpackedTypedDict), + s => py_schema_err!( + "Invalid var_kwargs mode: `{}`, expected `uniform` or `unpacked-typed-dict`", + s + ), + } + } +} + #[derive(Debug)] struct Parameter { positional: bool, @@ -29,6 +52,7 @@ pub struct ArgumentsValidator { parameters: Vec, positional_params_count: usize, var_args_validator: Option>, + var_kwargs_mode: VarKwargsMode, var_kwargs_validator: Option>, loc_by_alias: bool, extra: ExtraBehavior, @@ -117,6 +141,22 @@ impl BuildValidator for ArgumentsValidator { }); } + let py_var_kwargs_mode: Bound = schema + .get_as(intern!(py, "var_kwargs_mode"))? + .unwrap_or_else(|| PyString::new_bound(py, "uniform")); + + let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_str()?)?; + let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? { + Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)), + None => None, + }; + + if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() { + return py_schema_err!( + "`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`" + ); + } + Ok(Self { parameters, positional_params_count, @@ -124,10 +164,8 @@ impl BuildValidator for ArgumentsValidator { Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)), None => None, }, - var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? { - Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)), - None => None, - }, + var_kwargs_mode, + var_kwargs_validator, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?, } @@ -255,6 +293,9 @@ impl Validator for ArgumentsValidator { } } } + + let remaining_kwargs = PyDict::new_bound(py); + // if there are kwargs check any that haven't been processed yet if let Some(kwargs) = args.kwargs() { if kwargs.len() > used_kwargs.len() { @@ -278,26 +319,33 @@ impl Validator for ArgumentsValidator { Err(err) => return Err(err), }; if !used_kwargs.contains(either_str.as_cow()?.as_ref()) { - match self.var_kwargs_validator { - Some(ref validator) => match validator.validate(py, value.borrow_input(), state) { - Ok(value) => { - output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?; - } - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - errors.push(err.with_outer_location(raw_key.clone())); + match self.var_kwargs_mode { + VarKwargsMode::Uniform => match &self.var_kwargs_validator { + Some(validator) => match validator.validate(py, value.borrow_input(), state) { + Ok(value) => { + output_kwargs + .set_item(either_str.as_py_string(py, state.cache_str()), value)?; + } + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(err.with_outer_location(raw_key.clone())); + } + } + Err(err) => return Err(err), + }, + None => { + if let ExtraBehavior::Forbid = self.extra { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedKeywordArgument, + value, + raw_key.clone(), + )); } } - Err(err) => return Err(err), }, - None => { - if let ExtraBehavior::Forbid = self.extra { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::UnexpectedKeywordArgument, - value, - raw_key.clone(), - )); - } + VarKwargsMode::UnpackedTypedDict => { + // Save to the remaining kwargs, we will validate as a single dict: + remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?; } } } @@ -305,6 +353,24 @@ impl Validator for ArgumentsValidator { } } + if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict { + // `var_kwargs_validator` is guaranteed to be `Some`: + match self + .var_kwargs_validator + .as_ref() + .unwrap() + .validate(py, remaining_kwargs.as_any(), state) + { + Ok(value) => { + output_kwargs.update(value.downcast_bound::(py).unwrap().as_mapping())?; + } + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors); + } + Err(err) => return Err(err), + } + } + if !errors.is_empty() { Err(ValError::LineErrors(errors)) } else { diff --git a/tests/validators/test_arguments.py b/tests/validators/test_arguments.py index 915f05878..1fb17dcee 100644 --- a/tests/validators/test_arguments.py +++ b/tests/validators/test_arguments.py @@ -769,6 +769,19 @@ def test_build_non_default_follows(): ) +def test_build_missing_var_kwargs(): + with pytest.raises( + SchemaError, match="`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`" + ): + SchemaValidator( + { + 'type': 'arguments', + 'arguments_schema': [], + 'var_kwargs_mode': 'unpacked-typed-dict', + } + ) + + @pytest.mark.parametrize( 'input_value,expected', [ @@ -778,7 +791,7 @@ def test_build_non_default_follows(): ], ids=repr, ) -def test_kwargs(py_and_json: PyAndJson, input_value, expected): +def test_kwargs_uniform(py_and_json: PyAndJson, input_value, expected): v = py_and_json( { 'type': 'arguments', @@ -796,6 +809,48 @@ def test_kwargs(py_and_json: PyAndJson, input_value, expected): assert v.validate_test(input_value) == expected +@pytest.mark.parametrize( + 'input_value,expected', + [ + [ArgsKwargs((), {'x': 1}), ((), {'x': 1})], + [ArgsKwargs((), {'x': 1.0}), Err('x\n Input should be a valid integer [type=int_type,')], + [ArgsKwargs((), {'x': 1, 'z': 'str'}), ((), {'x': 1, 'y': 'str'})], + [ArgsKwargs((), {'x': 1, 'y': 'str'}), Err('y\n Extra inputs are not permitted [type=extra_forbidden,')], + ], +) +def test_kwargs_typed_dict(py_and_json: PyAndJson, input_value, expected): + v = py_and_json( + { + 'type': 'arguments', + 'arguments_schema': [], + 'var_kwargs_mode': 'unpacked-typed-dict', + 'var_kwargs_schema': { + 'type': 'typed-dict', + 'fields': { + 'x': { + 'type': 'typed-dict-field', + 'schema': {'type': 'int', 'strict': True}, + 'required': True, + }, + 'y': { + 'type': 'typed-dict-field', + 'schema': {'type': 'str'}, + 'required': False, + 'validation_alias': 'z', + }, + }, + 'config': {'extra_fields_behavior': 'forbid'}, + }, + } + ) + + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_test(input_value) + else: + assert v.validate_test(input_value) == expected + + @pytest.mark.parametrize( 'input_value,expected', [