Skip to content

Commit

Permalink
Flag invalid regexes (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Sep 27, 2024
1 parent a925ba7 commit a5bb63c
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

- Flag invalid regexes in arguments to functions like
`re.search` (#816)

## Version 0.13.1 (August 7, 2024)

- Use Trusted Publishing to publish releases (#806)
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
used(extensions.show_error)
used(extensions.has_extra_keys)
used(extensions.EnumName)
used(extensions.ValidRegex)
used(value.UNRESOLVED_VALUE) # keeping it around for now just in case
used(reexport)
used(patma)
Expand Down
9 changes: 8 additions & 1 deletion pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def unwrap(cls, typ: type, options: Options) -> type:
return typ


_BUILTIN_KNOWN_SIGNATURES = []
_BUILTIN_KNOWN_SIGNATURES = [implementation.get_default_argspecs_with_cache]

try:
import _pytest
Expand Down Expand Up @@ -784,6 +784,13 @@ def _uncached_get_argspec(
obj, allow_call=allow_call, type_params=type_params
)
if argspec is not None:
if impl is not None:
if isinstance(argspec, OverloadedSignature):
return OverloadedSignature(
[replace(sig, impl=impl) for sig in argspec.signatures]
)
else:
return replace(argspec, impl=impl)
return argspec

if is_typeddict(obj) and not is_typing_name(obj, "TypedDict"):
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/error_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __iter__(self) -> Iterator[Error]:
Error("generator_return", "Generator must return an iterable"),
Error("unsafe_comparison", "Non-overlapping equality checks"),
Error("must_use", "Value cannot be discarded"),
Error("invalid_regex", "Invalid regular expression"),
]
)

Expand Down
21 changes: 21 additions & 0 deletions pyanalyze/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ def _is_disallowed(self, value: "Value") -> bool:
)


@dataclass(frozen=True)
class ValidRegex(CustomCheck):
"""Custom check that allows only values that are valid regular expressions.
Example::
def func(arg: Annotated[str, ValidRegex()]) -> None:
...
func(".*") # ok
func("[") # error
"""

def can_assign(self, value: "Value", ctx: "CanAssignContext") -> "CanAssign":
error = pyanalyze.implementation.check_regex_in_value(value)
if error is not None:
return error
return {}


class _AsynqCallableMeta(type):
def __getitem__(
self, params: Tuple[Union[Literal[Ellipsis], List[object]], object]
Expand Down
69 changes: 69 additions & 0 deletions pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections
import collections.abc
import inspect
import re
import typing
from itertools import product
from typing import (
Expand All @@ -20,6 +21,8 @@
import qcore
import typing_extensions

import pyanalyze

from . import runtime
from .annotated_types import MaxLen, MinLen
from .annotations import type_from_value
Expand All @@ -31,7 +34,9 @@
from .signature import (
ANY_SIGNATURE,
CallContext,
ConcreteSignature,
ImplReturn,
OverloadedSignature,
ParameterKind,
Signature,
SigParameter,
Expand Down Expand Up @@ -2230,3 +2235,67 @@ def get_default_argspecs() -> Dict[object, Signature]:
)
signatures.append(sig)
return {sig.callable: sig for sig in signatures}


def check_regex(pattern: Union[str, bytes]) -> Optional[CanAssignError]:
try:
# TODO allow this without the useless isinstance()
if isinstance(pattern, str):
re.compile(pattern)
else:
re.compile(pattern)
except re.error as e:
return CanAssignError(
f"Invalid regex pattern: {e}", error_code=ErrorCode.invalid_regex
)
return None


def check_regex_in_value(value: Value) -> Optional[CanAssignError]:
errors = []
for subval in flatten_values(value):
if not isinstance(subval, KnownValue):
continue
if not isinstance(subval.val, (str, bytes)):
continue
maybe_error = check_regex(subval.val)
if maybe_error is not None:
errors.append(maybe_error)
if errors:
if len(errors) == 1:
return errors[0]
return pyanalyze.value.CanAssignError(
"Invalid regex", errors, pyanalyze.error_code.ErrorCode.invalid_regex
)
return None


def _re_impl_with_pattern(ctx: CallContext) -> Value:
pattern = ctx.vars["pattern"]
error = check_regex_in_value(pattern)
if error is not None:
ctx.show_error(error.message, error_code=ErrorCode.invalid_regex, arg="pattern")
return ctx.inferred_return_value


def get_default_argspecs_with_cache(
asc: "pyanalyze.arg_spec.ArgSpecCache",
) -> Dict[object, ConcreteSignature]:
sigs = {}
for func in (
re.compile,
re.search,
re.match,
re.fullmatch,
re.split,
re.findall,
re.finditer,
re.sub,
re.subn,
):
sig = asc.get_argspec(func, impl=_re_impl_with_pattern)
assert isinstance(
sig, (Signature, OverloadedSignature)
), f"failed to find signature for {func}: {sig}"
sigs[func] = sig
return sigs
11 changes: 10 additions & 1 deletion pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ class CallContext:
node: Optional[ast.AST]
"""AST node corresponding to the function call. Useful for
showing errors."""
sig: "Signature"
inferred_return_value: Value

def ast_for_arg(self, arg: str) -> Optional[ast.AST]:
composite = self.composite_for_arg(arg)
Expand Down Expand Up @@ -1332,8 +1334,14 @@ def check_call_with_bound_args(
visitor=ctx.visitor,
composites=composites,
node=ctx.node,
sig=self,
inferred_return_value=return_value,
)
return_value = self.impl(call_ctx)
with ctx.visitor.catch_errors() as caught_errors:
return_value = self.impl(call_ctx)
if caught_errors:
ctx.visitor.show_caught_errors(caught_errors)
had_error = True
elif self.evaluator is not None:
varmap = {
param: composite.value
Expand All @@ -1347,6 +1355,7 @@ def check_call_with_bound_args(
)
return_value, errors = self.evaluator.evaluate(eval_ctx)
for error in errors:
had_error = True
error_node = None
if error.argument is not None:
composite = bound_args[error.argument][1]
Expand Down
31 changes: 31 additions & 0 deletions pyanalyze/test_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,3 +1476,34 @@ def test_namedtuple_after_3_13(self):
def capybara() -> None:
# on 3.13+ we get a second error from calling the runtime
NamedTuple("x", None, y=int) # E: incompatible_call # E: incompatible_call


class TestRegex(TestNameCheckVisitorBase):
@assert_passes()
def test_compile(self):
import re

def capybara():
re.compile("a")
re.compile(b"a")
re.compile("[") # E: incompatible_call
re.compile(b"[") # E: incompatible_call

re.sub(r"a", "b", "c")
re.sub(rb"(", b"b", b"c") # E: incompatible_call

re.match(r"a", "b")
re.match(rb"(", b"b") # E: incompatible_call

@assert_passes()
def test_extension(self):
from typing_extensions import Annotated

from pyanalyze.extensions import ValidRegex

def f(x: Annotated[str, ValidRegex()]):
pass

def capybara():
f("x")
f(r"[") # E: invalid_regex

0 comments on commit a5bb63c

Please sign in to comment.