From 15ca1728ae70fb69eaa2fe14eff1374694a8d874 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Wed, 13 May 2020 22:29:17 +0000 Subject: [PATCH] Add ability to pass `unknown` in parse calls This adds support for passing the `unknown` parameter in two major locations: Parser instantiation, and Parser.parse calls. use_args and use_kwargs are just parse wrappers, and they need to pass it through as well. It also adds support for a class-level default for unknown, `Parser.DEFAULT_UNKNOWN`, which sets `unknown` for any future parser instances. Explicit tweaks to handle this were necessary in asyncparser and PyramidParser, due to odd method signatures. Support is tested in the core tests, but not the various framework tests. Add a 6.2.0 (Unreleased) changelog entry with detail on this change. The changelog states that we will change the DEFAULT_UNKNOWN default in a future major release. Presumably we'll make it `EXCLUDE`, but I'd like to make it location-dependent if feasible, so I didn't commit to anything in the phrasing. --- CHANGELOG.rst | 45 +++++++++++++++++++++ src/webargs/asyncparser.py | 11 ++++- src/webargs/core.py | 22 +++++++++- src/webargs/pyramidparser.py | 4 ++ tests/test_core.py | 78 +++++++++++++++++++++++++++++------- 5 files changed, 142 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f72d2b00..43145b53 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,51 @@ Changelog --------- +6.2.0 (Unreleased) +****************** + +Features: + +* Add a new ``unknown`` parameter to ``Parser.parse``, ``Parser.use_args``, and + ``Parser.use_kwargs``. When set, it will be passed to the ``Schema.load`` + call. If set to ``None`` (the default), no value is passed, so the schema's + ``unknown`` behavior is used. + +This allows usages like + +.. code-block:: python + + import marshmallow as ma + + # marshmallow 3 only, for use of ``unknown`` and ``EXCLUDE`` + @parser.use_kwargs( + {"q1": ma.fields.Int(), "q2": ma.fields.Int()}, location="query", unknown=ma.EXCLUDE + ) + def foo(q1, q2): + ... + +* Add the ability to set defaults for ``unknown`` on either a Parser instance + or Parser class. Set ``Parser.DEFAULT_UNKNOWN`` on a parser class to apply a value + to any new parser instances created from that class, or set ``unknown`` during + ``Parser`` initialization. + +Usages are varied, but include + +.. code-block:: python + + import marshmallow as ma + from webargs.flaskparser import FlaskParser + + parser = FlaskParser(unknown=ma.INCLUDE) + + # as well as... + class MyParser(FlaskParser): + DEFAULT_UNKNOWN = ma.INCLUDE + + + parser = MyParser() + + 6.1.0 (2020-04-05) ****************** diff --git a/src/webargs/asyncparser.py b/src/webargs/asyncparser.py index 1ba77c70..914dd809 100644 --- a/src/webargs/asyncparser.py +++ b/src/webargs/asyncparser.py @@ -9,6 +9,7 @@ from marshmallow.fields import Field import marshmallow as ma +from webargs.compat import MARSHMALLOW_VERSION_INFO from webargs import core Request = typing.TypeVar("Request") @@ -28,6 +29,7 @@ async def parse( req: Request = None, *, location: str = None, + unknown: str = None, validate: Validate = None, error_status_code: typing.Union[int, None] = None, error_headers: typing.Union[typing.Mapping[str, str], None] = None @@ -38,6 +40,10 @@ async def parse( """ req = req if req is not None else self.get_default_request() location = location or self.location + unknown = unknown or self.unknown + load_kwargs = ( + {"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {} + ) if req is None: raise ValueError("Must pass req object") data = None @@ -47,7 +53,7 @@ async def parse( location_data = await self._load_location_data( schema=schema, req=req, location=location ) - result = schema.load(location_data) + result = schema.load(location_data, **load_kwargs) data = result.data if core.MARSHMALLOW_VERSION_INFO[0] < 3 else result self._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -111,6 +117,7 @@ def use_args( req: typing.Optional[Request] = None, *, location: str = None, + unknown=None, as_kwargs: bool = False, validate: Validate = None, error_status_code: typing.Optional[int] = None, @@ -143,6 +150,7 @@ async def wrapper(*args, **kwargs): argmap, req=req_obj, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, @@ -165,6 +173,7 @@ def wrapper(*args, **kwargs): argmap, req=req_obj, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/src/webargs/core.py b/src/webargs/core.py index d242a931..0dc96a06 100644 --- a/src/webargs/core.py +++ b/src/webargs/core.py @@ -101,11 +101,15 @@ class Parser: etc. :param str location: Default location to use for data + :param str unknown: Default value for ``unknown`` in ``parse``, + ``use_args``, and ``use_kwargs`` :param callable error_handler: Custom error handler function. """ #: Default location to check for data DEFAULT_LOCATION = "json" + #: Default value to use for 'unknown' on schema load + DEFAULT_UNKNOWN = None #: The marshmallow Schema class to use when creating new schemas DEFAULT_SCHEMA_CLASS = ma.Schema #: Default status code to return for validation errors @@ -125,10 +129,13 @@ class Parser: "json_or_form": "load_json_or_form", } - def __init__(self, location=None, *, error_handler=None, schema_class=None): + def __init__( + self, location=None, *, unknown=None, error_handler=None, schema_class=None + ): self.location = location or self.DEFAULT_LOCATION self.error_callback = _callable_or_raise(error_handler) self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS + self.unknown = unknown or self.DEFAULT_UNKNOWN def _get_loader(self, location): """Get the loader function for the given location. @@ -222,6 +229,7 @@ def parse( req=None, *, location=None, + unknown=None, validate=None, error_status_code=None, error_headers=None @@ -236,6 +244,8 @@ def parse( Can be any of the values in :py:attr:`~__location_map__`. By default, that means one of ``('json', 'query', 'querystring', 'form', 'headers', 'cookies', 'files', 'json_or_form')``. + :param str unknown: A value to pass for ``unknown`` when calling the + schema's ``load`` method (marshmallow 3 only). :param callable validate: Validation function or list of validation functions that receives the dictionary of parsed arguments. Validator either returns a boolean or raises a :exc:`ValidationError`. @@ -248,6 +258,10 @@ def parse( """ req = req if req is not None else self.get_default_request() location = location or self.location + unknown = unknown or self.unknown + load_kwargs = ( + {"unknown": unknown} if MARSHMALLOW_VERSION_INFO[0] >= 3 and unknown else {} + ) if req is None: raise ValueError("Must pass req object") data = None @@ -257,7 +271,7 @@ def parse( location_data = self._load_location_data( schema=schema, req=req, location=location ) - result = schema.load(location_data) + result = schema.load(location_data, **load_kwargs) data = result.data if MARSHMALLOW_VERSION_INFO[0] < 3 else result self._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: @@ -307,6 +321,7 @@ def use_args( req=None, *, location=None, + unknown=None, as_kwargs=False, validate=None, error_status_code=None, @@ -325,6 +340,8 @@ def greet(args): of argname -> `marshmallow.fields.Field` pairs, or a callable which accepts a request and returns a `marshmallow.Schema`. :param str location: Where on the request to load values. + :param str unknown: A value to pass for ``unknown`` when calling the + schema's ``load`` method (marshmallow 3 only). :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -356,6 +373,7 @@ def wrapper(*args, **kwargs): argmap, req=req_obj, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/src/webargs/pyramidparser.py b/src/webargs/pyramidparser.py index 7f9d5c01..91018f41 100644 --- a/src/webargs/pyramidparser.py +++ b/src/webargs/pyramidparser.py @@ -113,6 +113,7 @@ def use_args( req=None, *, location=core.Parser.DEFAULT_LOCATION, + unknown=None, as_kwargs=False, validate=None, error_status_code=None, @@ -127,6 +128,8 @@ def use_args( which accepts a request and returns a `marshmallow.Schema`. :param req: The request object to parse. Pulled off of the view by default. :param str location: Where on the request to load values. + :param str unknown: A value to pass for ``unknown`` when calling the + schema's ``load`` method (marshmallow 3 only). :param bool as_kwargs: Whether to insert arguments as keyword arguments. :param callable validate: Validation function that receives the dictionary of parsed arguments. If the function returns ``False``, the parser @@ -155,6 +158,7 @@ def wrapper(obj, *args, **kwargs): argmap, req=request, location=location, + unknown=unknown, validate=validate, error_status_code=error_status_code, error_headers=error_headers, diff --git a/tests/test_core.py b/tests/test_core.py index a9a32e98..1995636d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -108,7 +108,11 @@ def test_parse(parser, web_request): @pytest.mark.skipif( MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3" ) -def test_parse_with_unknown_behavior_specified(parser, web_request): +@pytest.mark.parametrize( + "set_location", + ["schema_instance", "parse_call", "parser_default", "parser_class_default"], +) +def test_parse_with_unknown_behavior_specified(parser, web_request, set_location): # This is new in webargs 6.x ; it's the way you can "get back" the behavior # of webargs 5.x in which extra args are ignored from marshmallow import EXCLUDE, INCLUDE, RAISE @@ -119,17 +123,65 @@ class CustomSchema(Schema): username = fields.Field() password = fields.Field() + def parse_with_desired_behavior(value): + if set_location == "schema_instance": + if value is not None: + return parser.parse(CustomSchema(unknown=value), web_request) + else: + return parser.parse(CustomSchema(), web_request) + elif set_location == "parse_call": + return parser.parse(CustomSchema(), web_request, unknown=value) + elif set_location == "parser_default": + parser.unknown = value + return parser.parse(CustomSchema(), web_request) + elif set_location == "parser_class_default": + + class CustomParser(MockRequestParser): + DEFAULT_UNKNOWN = value + + return CustomParser().parse(CustomSchema(), web_request) + else: + raise NotImplementedError + # with no unknown setting or unknown=RAISE, it blows up with pytest.raises(ValidationError, match="Unknown field."): - parser.parse(CustomSchema(), web_request) + parse_with_desired_behavior(None) with pytest.raises(ValidationError, match="Unknown field."): - parser.parse(CustomSchema(unknown=RAISE), web_request) + parse_with_desired_behavior(RAISE) # with unknown=EXCLUDE the data is omitted - ret = parser.parse(CustomSchema(unknown=EXCLUDE), web_request) + ret = parse_with_desired_behavior(EXCLUDE) assert {"username": 42, "password": 42} == ret # with unknown=INCLUDE it is added even though it isn't part of the schema - ret = parser.parse(CustomSchema(unknown=INCLUDE), web_request) + ret = parse_with_desired_behavior(INCLUDE) + assert {"username": 42, "password": 42, "fjords": 42} == ret + + +@pytest.mark.skipif( + MARSHMALLOW_VERSION_INFO[0] < 3, reason="unknown=... added in marshmallow3" +) +def test_parse_with_explicit_unknown_overrides_schema(parser, web_request): + # this test ensures that if you specify unknown=... in your parse call (or + # use_args) it takes precedence over a setting in the schema object + from marshmallow import EXCLUDE, INCLUDE, RAISE + + web_request.json = {"username": 42, "password": 42, "fjords": 42} + + class CustomSchema(Schema): + username = fields.Field() + password = fields.Field() + + # setting RAISE in the parse call overrides schema setting + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(unknown=EXCLUDE), web_request, unknown=RAISE) + with pytest.raises(ValidationError, match="Unknown field."): + parser.parse(CustomSchema(unknown=INCLUDE), web_request, unknown=RAISE) + + # and the reverse -- setting EXCLUDE or INCLUDE in the parse call overrides + # a schema with RAISE already set + ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=EXCLUDE) + assert {"username": 42, "password": 42} == ret + ret = parser.parse(CustomSchema(unknown=RAISE), web_request, unknown=INCLUDE) assert {"username": 42, "password": 42, "fjords": 42} == ret @@ -756,22 +808,18 @@ def test_warning_raised_if_schema_is_not_in_strict_mode(self, web_request, parse assert "strict=True" in str(warning.message) def test_use_kwargs_stacked(self, web_request, parser): + parse_kwargs = {} if MARSHMALLOW_VERSION_INFO[0] >= 3: from marshmallow import EXCLUDE - class PageSchema(Schema): - page = fields.Int() - - pageschema = PageSchema(unknown=EXCLUDE) - userschema = self.UserSchema(unknown=EXCLUDE) - else: - pageschema = {"page": fields.Int()} - userschema = self.UserSchema(**strict_kwargs) + parse_kwargs = {"unknown": EXCLUDE} web_request.json = {"email": "foo@bar.com", "password": "bar", "page": 42} - @parser.use_kwargs(pageschema, web_request) - @parser.use_kwargs(userschema, web_request) + @parser.use_kwargs({"page": fields.Int()}, web_request, **parse_kwargs) + @parser.use_kwargs( + self.UserSchema(**strict_kwargs), web_request, **parse_kwargs + ) def viewfunc(email, password, page): return {"email": email, "password": password, "page": page}