diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 340fe83ed..e59e56ed2 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -831,6 +831,14 @@ class String(Field): "invalid_utf8": "Not a valid utf-8 string.", } + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Insert validation into self.validators so that multiple errors can be stored. + validator = validate.ProhibitNullCharactersValidator( + error=self.error_messages["invalid"] + ) + self.validators.insert(0, validator) + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]: if value is None: return None diff --git a/src/marshmallow/validate.py b/src/marshmallow/validate.py index e57073a4e..74d9a1ddc 100644 --- a/src/marshmallow/validate.py +++ b/src/marshmallow/validate.py @@ -185,6 +185,32 @@ def __call__(self, value) -> typing.Any: return value +class ProhibitNullCharactersValidator(Validator): + """Validate string not having Null Character + + :param error: Error message to raise in case of a validation error. Can be + interpolated with `{input}`. + """ + + NULL_REGEX = re.compile( + r"\0", + ) + + def __init__(self, *, error: typing.Optional[str] = None): + self.error = error or self.default_message # type: str + + def _format_error(self, value) -> typing.Any: + return self.error.format(input=value) + + def __call__(self, value) -> typing.Any: + message = self._format_error(value) + + if value and self.NULL_REGEX.search(str(value)): + raise ValidationError(message) + + return value + + class Range(Validator): """Validator which succeeds if the value passed to it is within the specified range. If ``min`` is not specified, or is specified as `None`, diff --git a/tests/test_fields.py b/tests/test_fields.py index 4e5e4cc01..4955e04c5 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -92,6 +92,13 @@ class MySchema(Schema): result = MySchema().dump({"name": "Monty", "foo": 42}) assert result == {"_NaMe": "Monty"} + def test_string_field_null_char(self): + class MySchema(Schema): + name = fields.String() + + with pytest.raises(ValidationError): + MySchema().load({"name": "a\0b"}) + class TestParentAndName: class MySchema(Schema):