Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added additional type hints #87

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/benchmark/executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,24 @@

import atheris
import contextlib
from typing import Any, Callable, Iterator, List, Tuple, TypeVar

# Use new atheris instrumentation only if on new atheris
if "instrument_func" in dir(atheris):
instrument_func = atheris.instrument_func
instrument_imports = atheris.instrument_imports
instrument_all = atheris.instrument_all
else:
T = TypeVar("T")

def instrument_func(x):
def instrument_func(x: Callable[..., T]) -> Callable[..., T]:
return x

def instrument_all():
pass

@contextlib.contextmanager
def instrument_imports(*args, **kwargs):
def instrument_imports(*args: Any, **kwargs: Any) -> Iterator[None]:
yield None


Expand All @@ -46,14 +48,14 @@ def instrument_imports(*args, **kwargs):
import io


def _set_nonblocking(fd):
def _set_nonblocking(fd: int):
"""Set the specified fd to a nonblocking mode."""
oflags = fcntl.fcntl(fd, fcntl.F_GETFL)
nflags = oflags | os.O_NONBLOCK
fcntl.fcntl(fd, fcntl.F_SETFL, nflags)


def _benchmark_child(test_one_input, num_runs, pipe, args, inst_all):
def _benchmark_child(test_one_input: Callable[[bytes], None], num_runs: int, pipe: Tuple[int, int], args: List[str], inst_all: bool):
os.close(pipe[0])
os.dup2(pipe[1], 1)
os.dup2(pipe[1], 2)
Expand All @@ -64,7 +66,7 @@ def _benchmark_child(test_one_input, num_runs, pipe, args, inst_all):
counter = [0]
start = time.time()

def wrapped_test_one_input(data):
def wrapped_test_one_input(data: bytes):
counter[0] += 1
if counter[0] == num_runs:
print(f"\nbenchmark_duration={time.time() - start}")
Expand All @@ -76,11 +78,11 @@ def wrapped_test_one_input(data):
assert False # Does not return


def run_benchmark(test_one_input,
num_runs,
timeout=10,
inst_all=False,
args=[]):
def run_benchmark(test_one_input: Callable[[bytes], None],
num_runs: int,
timeout: float = 10,
inst_all: bool = False,
args: List[str] = []):
"""Fuzz test_one_input() in a subprocess.

This forks a child, and in the child, runs atheris.Setup(test_one_input) and
Expand Down Expand Up @@ -140,7 +142,7 @@ def run_benchmark(test_one_input,


@instrument_func
def low_cyclomatic(data):
def low_cyclomatic(data: bytes):
x = 0
x = 1
x = 2
Expand Down Expand Up @@ -244,7 +246,7 @@ def low_cyclomatic(data):


@instrument_func
def high_cyclomatic(data):
def high_cyclomatic(data: bytes):
for c in data:
if c == 0:
c = 38
Expand Down Expand Up @@ -760,15 +762,15 @@ def high_cyclomatic(data):
c = 7


def json_fuzz(data):
def json_fuzz(data: bytes):
try:
json.loads(data.decode("utf-8", "surrogatepass"))
except Exception as e:
pass


@instrument_func
def zip_fuzz(data):
def zip_fuzz(data: bytes):
try:
with io.BytesIO(data) as f:
pz = zipfile.ZipFile(f)
Expand Down
65 changes: 33 additions & 32 deletions src/coverage_g3test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dis
import re
import unittest
from typing import Any, Tuple
from unittest import mock

with atheris.instrument_imports():
Expand All @@ -29,7 +30,7 @@


@atheris.instrument_func
def if_func(a):
def if_func(a: float) -> int:
x = a
if x:
return 2
Expand All @@ -38,89 +39,89 @@ def if_func(a):


@atheris.instrument_func
def cmp_less(a, b):
def cmp_less(a: float, b: float):
return a < b


@atheris.instrument_func
def cmp_greater(a, b):
def cmp_greater(a: float, b: float):
return a > b


@atheris.instrument_func
def cmp_equal_nested(a, b, c):
def cmp_equal_nested(a: float, b: float, c: float) -> bool:
return (a == b) == c


@atheris.instrument_func
def cmp_const_less(a):
def cmp_const_less(a: float) -> bool:
return 1 < a


@atheris.instrument_func
def cmp_const_less_inverted(a):
def cmp_const_less_inverted(a: float) -> bool:
return a < 1


@atheris.instrument_func
def decorator_instrumented(x):
def decorator_instrumented(x: int):
return 2 * x


@atheris.instrument_func
def while_loop(a):
def while_loop(a: float):
while a:
a -= 1


@atheris.instrument_func
def regex_match(re_obj, a):
def regex_match(re_obj: re.Pattern, a: str):
re_obj.match(a)


@atheris.instrument_func
def starts_with(s, prefix):
def starts_with(s: str, prefix: str):
s.startswith(prefix)


@atheris.instrument_func
def ends_with(s, suffix):
def ends_with(s: str, suffix: str):
s.endswith(suffix)


# Verifying that no tracing happens when var args are passed in to
# startswith method calls
@atheris.instrument_func
def starts_with_var_args(s, *args):
def starts_with_var_args(s: str, *args: Any):
s.startswith(*args)


# Verifying that no tracing happens when var args are passed in to
# endswith method calls
@atheris.instrument_func
def ends_with_var_args(s, *args):
def ends_with_var_args(s: str, *args: Any):
s.startswith(*args)


class FakeStr:

def startswith(self, s, prefix):
def startswith(self, s: str, prefix: str):
pass

def endswith(self, s, suffix):
def endswith(self, s: str, suffix: str):
pass


# Verifying that even though this code gets patched, no tracing happens
@atheris.instrument_func
def fake_starts_with(s, prefix):
def fake_starts_with(s: str, prefix: str):
fake_str = FakeStr()
fake_str.startswith(s=s, prefix=prefix)


# Verifying that even though this code gets patched, no tracing happens
@atheris.instrument_func
def fake_ends_with(s, suffix):
def fake_ends_with(s: str, suffix: str):
fake_str = FakeStr()
fake_str.endswith(s, suffix)

Expand Down Expand Up @@ -161,16 +162,16 @@ def multi_instrumented(x):
@mock.patch.object(atheris, "_trace_branch")
class CoverageTest(unittest.TestCase):

def testImport(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock):
def testImport(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock,
trace_regex_match_mock: mock.MagicMock):
trace_cmp_mock.side_effect = original_trace_cmp

trace_branch_mock.assert_not_called()
Sequence.load(b"0\0")
trace_branch_mock.assert_called()

def testBranch(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock):
def testBranch(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock,
trace_regex_match_mock: mock.MagicMock):
trace_branch_mock.assert_not_called()
if_func(True)
first_call_set = trace_branch_mock.call_args_list
Expand All @@ -188,14 +189,14 @@ def testBranch(self, trace_branch_mock, trace_cmp_mock,
self.assertNotEqual(first_call_set, third_call_set)

def testWhile(
self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock
self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, trace_regex_match_mock: mock.MagicMock
):
trace_branch_mock.assert_not_called()
while_loop(1)
trace_branch_mock.assert_called()

def testRegex(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock):
def testRegex(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock,
trace_regex_match_mock: mock.MagicMock):
trace_branch_mock.reset_mock()
trace_branch_mock.assert_not_called()
trace_regex_match_mock.assert_not_called()
Expand All @@ -204,7 +205,7 @@ def testRegex(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock.assert_called()

def testStrMethods(
self, trace_branch_mock, trace_cmp_mock, trace_regex_match_mock
self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock, trace_regex_match_mock: mock.MagicMock
):
trace_branch_mock.assert_not_called()
trace_regex_match_mock.assert_not_called()
Expand Down Expand Up @@ -252,16 +253,16 @@ def testStrMethods(
trace_regex_match_mock.assert_not_called()
trace_regex_match_mock.reset_mock()

def assertTraceCmpWas(self, call_args, left, right, op, left_is_const):
def assertTraceCmpWas(self, call_args: Tuple[int, int, int, int, bool], left: int, right: int, op: str, left_is_const: bool):
"""Compare a _trace_cmp call to expected values."""
# call_args: tuple(left, right, opid, idx, left_is_const)
self.assertEqual(call_args[0], left)
self.assertEqual(call_args[1], right)
self.assertEqual(dis.cmp_op[call_args[2]], op)
self.assertEqual(call_args[4], left_is_const)

def testCompare(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock):
def testCompare(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock,
trace_regex_match_mock: mock.MagicMock):
trace_cmp_mock.side_effect = original_trace_cmp

self.assertTrue(cmp_less(1, 2))
Expand Down Expand Up @@ -297,8 +298,8 @@ def testCompare(self, trace_branch_mock, trace_cmp_mock,
self.assertNotEqual(second_cmp_idx, fifth_cmp_idx)
self.assertNotEqual(fourth_cmp_idx, fifth_cmp_idx)

def testConstCompare(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock):
def testConstCompare(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock,
trace_regex_match_mock: mock.MagicMock):
trace_cmp_mock.side_effect = original_trace_cmp

self.assertTrue(cmp_const_less(2))
Expand All @@ -309,8 +310,8 @@ def testConstCompare(self, trace_branch_mock, trace_cmp_mock,
self.assertTraceCmpWas(trace_cmp_mock.call_args[0], 1, 3, ">", True)
trace_cmp_mock.reset_mock()

def testInstrumentationAppliedOnce(self, trace_branch_mock, trace_cmp_mock,
trace_regex_match_mock):
def testInstrumentationAppliedOnce(self, trace_branch_mock: mock.MagicMock, trace_cmp_mock: mock.MagicMock,
trace_regex_match_mock: mock.MagicMock):
trace_branch_mock.assert_not_called()
multi_instrumented(7)
trace_branch_mock.assert_called_once()
Expand Down
Loading
Loading