Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite assert statements, refs #12 #13

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/cosmetic/assert_statements/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import unittest


class NotACase:
def not_a_case(self):
assert 1 == 2


class WrongTestCase(unittest.TestCase):
def not_a_case(self):
assert isinstance(1, int)

def test_without_self():
assert 3 is 4 # comment


class CorrectTestCase(unittest.TestCase):
def test_regular(self):
assert self is not None # comment
assert isinstance(
self, # this is
unittest.TestCase, # multiline
)
assert 1 == 1 # eq

assert True
assert True
27 changes: 27 additions & 0 deletions tests/cosmetic/assert_statements/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import unittest


class NotACase:
def not_a_case(self):
assert 1 == 2


class WrongTestCase(unittest.TestCase):
def not_a_case(self):
assert isinstance(1, int)

def test_without_self():
assert 3 is 4 # comment


class CorrectTestCase(unittest.TestCase):
def test_regular(self):
self.assertIsNotNone(self)
self.assertIsInstance(
self, # this is
unittest.TestCase
)
self.assertEqual(1, 1)

assert True
assert True
46 changes: 41 additions & 5 deletions tests/test_teyit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
TEST_DATA_DIR = Path(__file__).parent


class TeyitTestCase(unittest.TestCase):
class BaseCase:
def assertRewrites(self, given, expected, case_count=None, **kwargs):
with self.subTest(given=given, expected=expected, **kwargs):
source, cases = teyit.refactor_until_deterministic(given)
Expand All @@ -18,24 +18,26 @@ def assertRewrites(self, given, expected, case_count=None, **kwargs):
def assertNotRewrites(self, given):
self.assertRewrites(given, given, case_count=0)


class TeyitTestCase(BaseCase, unittest.TestCase):
def test_rewrite(self):
func = ast.parse("self.assertTrue(x is None)", mode="eval").body
rewrite = teyit.Rewrite(func, "assertIsNone", [func.args[0].left])
rewrite = teyit.FunctionRewrite(func, "assertIsNone", [func.args[0].left])
self.assertEqual(
ast.unparse(rewrite.build_node()), "self.assertIsNone(x)"
)
self.assertEqual(rewrite.get_arg_offset(), 0)

func = ast.parse("self.assertTrue(a == b)", mode="eval").body
rewrite = teyit.Rewrite(
rewrite = teyit.FunctionRewrite(
func, "assertEqual", [func.args[0].left, *func.args[0].comparators]
)
self.assertEqual(rewrite.get_arg_offset(), 1)

func = ast.parse(
"self.assertTrue(x is None, message='XYZ')", mode="eval"
).body
rewrite = teyit.Rewrite(func, "assertIsNone", [func.args[0].left])
rewrite = teyit.FunctionRewrite(func, "assertIsNone", [func.args[0].left])
self.assertEqual(
ast.unparse(rewrite.build_node()),
"self.assertIsNone(x, message='XYZ')",
Expand All @@ -44,7 +46,7 @@ def test_rewrite(self):
func = ast.parse(
"self.assertIs(x, None, message='XYZ')", mode="eval"
).body
rewrite = teyit.Rewrite(func, "assertIsNone", [func.args[0]])
rewrite = teyit.FunctionRewrite(func, "assertIsNone", [func.args[0]])
self.assertEqual(rewrite.get_arg_offset(), -1)

def test_assert_rewriter_basic(self):
Expand Down Expand Up @@ -211,6 +213,7 @@ def test_assert_rewriter_deprecated(self):
)

def test_assert_rewriter_cosmetic(self):
self.maxDiff = None
for case in (TEST_DATA_DIR / "cosmetic").iterdir():
self.assertRewrites(
(case / "input.py").read_text(),
Expand All @@ -219,5 +222,38 @@ def test_assert_rewriter_cosmetic(self):
)


_TEST_CASE_TEMPLATE = """
class TestCase:
def test_my(self):
{0}
"""


class AssertStmtTestCase(BaseCase, unittest.TestCase):
def assertStmtRewrites(self, given, expected):
self.assertRewrites(
_TEST_CASE_TEMPLATE.format(given),
_TEST_CASE_TEMPLATE.format(expected),
)

def test_assert_stmt_to_call(self):
self.assertStmtRewrites(
"assert x == y",
"self.assertEqual(x, y)",
)
self.assertStmtRewrites(
"assert x != y",
"self.assertNotEqual(x, y)",
)
self.assertStmtRewrites(
"assert x is y",
"self.assertIs(x, y)",
)
self.assertStmtRewrites(
"assert isinstance(x, y)",
"self.assertIsInstance(x, y)",
)


if __name__ == "__main__":
unittest.main()
102 changes: 96 additions & 6 deletions teyit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@
}


@dataclass
class Rewrite:
@dataclass(frozen=True)
class FunctionRewrite:
"""
Used to rewrite test functions.

Example: ``self.assertTrue(x == y)`` -> ``self.assertEqual(x, y)``
"""

node: ast.Call
func: str
args: List[ast.AST]
Expand All @@ -69,6 +75,34 @@ def get_arg_offset(self):
return len(new_node.args + new_node.keywords) - prev_args


@dataclass(frozen=True)
class AssertStmtRewrite:
"""
Used to rewrite ``assert`` statement into function calls.

Example: ``assert x == y`` -> ``self.assertEqual(x, y)``
"""

node: ast.Assert
expr: FunctionRewrite

def __hash__(self):
return hash(id(self))

@lru_cache(maxsize=1)
def build_node(self):
new_node = self.expr.build_node()
new_node.lineno = self.node.lineno
new_node.col_offset = self.node.col_offset
new_node.end_lineno = self.node.end_lineno
new_node.end_col_offset = self.node.end_col_offset
return new_node

@lru_cache(maxsize=1)
def get_arg_offset(self):
return 2


class _AssertRewriter(ast.NodeVisitor):
def __init__(self, blacklist=frozenset(), *args, **kwargs):
self.asserts = []
Expand All @@ -87,7 +121,7 @@ def visit_Call(self, node):
with suppress(Exception):
if node.func.attr in DEPRECATED_ALIASES:
self.asserts.append(
Rewrite(
FunctionRewrite(
node, DEPRECATED_ALIASES[node.func.attr], node.args
)
)
Expand Down Expand Up @@ -133,7 +167,7 @@ def visit_assertTrue(self, node, positive=True):
args = [*expr.args, *args]
else:
return None
return Rewrite(node, func, args)
return FunctionRewrite(node, func, args)

def visit_assertFalse(self, node):
return self.visit_assertTrue(node, positive=False)
Expand All @@ -155,7 +189,7 @@ def visit_assertIs(self, node, positive=True):
args = [left, *args]
else:
return None
return Rewrite(node, func, args)
return FunctionRewrite(node, func, args)

def visit_assertIsNot(self, node):
return self.visit_assertIs(node, positive=False)
Expand All @@ -168,7 +202,63 @@ def visit_assertDictContainsSubset(self, node):
ast.Dict(keys=[None, None], values=[right, left]),
*args,
]
return Rewrite(node, func, args)
return FunctionRewrite(node, func, args)

def visit_ClassDef(self, node):
for child_node in ast.walk(node):
if not isinstance(
child_node, (ast.FunctionDef, ast.AsyncFunctionDef)
):
continue

if not _looks_like_test(node.name) or not _looks_like_test(
child_node.name
):
continue

if not any(arg.arg == "self" for arg in child_node.args.args):
continue

visitor = _AssertStmtFinder()
visitor.visit(child_node)

for assert_stmt in visitor.stmts:
# with suppress(Exception):
new_expr = self.visit_assertTrue(
ast.Call(
ast.Attribute(
ast.Name("self", ast.Load),
"assertTrue",
ast.Load,
),
[assert_stmt.test],
[], # TODO: support `msg` part
)
)
if new_expr is None:
continue

self.asserts.append(
AssertStmtRewrite(
node=assert_stmt,
expr=new_expr,
)
)
self.generic_visit(node)


class _AssertStmtFinder(ast.NodeVisitor):
def __init__(self):
self.stmts = []

def visit_Assert(self, node):
# Is called with right context: inside an instance test method.
self.stmts.append(node)
self.generic_visit(node)


def _looks_like_test(name):
return "test" in name or "Test" in name


class _FormattedUnparser(PreciseUnparser):
Expand Down