diff --git a/tests/cosmetic/assert_statements/input.py b/tests/cosmetic/assert_statements/input.py new file mode 100644 index 0000000..1f8245f --- /dev/null +++ b/tests/cosmetic/assert_statements/input.py @@ -0,0 +1,27 @@ +import unittest + + +class NotACase: + def not_a_case(self): + assert 1 == 2 + + +class WrongTestCase(unittest.TestCase): + def not_a_case(self): + assert isinstance(1, int) + + def test_without_self(): + assert 3 is 4 # comment + + +class CorrectTestCase(unittest.TestCase): + def test_regular(self): + assert self is not None # comment + assert isinstance( + self, # this is + unittest.TestCase, # multiline + ) + assert 1 == 1 # eq + + assert True +assert True diff --git a/tests/cosmetic/assert_statements/output.py b/tests/cosmetic/assert_statements/output.py new file mode 100644 index 0000000..aef9ebc --- /dev/null +++ b/tests/cosmetic/assert_statements/output.py @@ -0,0 +1,27 @@ +import unittest + + +class NotACase: + def not_a_case(self): + assert 1 == 2 + + +class WrongTestCase(unittest.TestCase): + def not_a_case(self): + assert isinstance(1, int) + + def test_without_self(): + assert 3 is 4 # comment + + +class CorrectTestCase(unittest.TestCase): + def test_regular(self): + self.assertIsNotNone(self) + self.assertIsInstance( + self, # this is + unittest.TestCase + ) + self.assertEqual(1, 1) + + assert True +assert True diff --git a/tests/test_teyit.py b/tests/test_teyit.py index e95045d..6e6ab80 100644 --- a/tests/test_teyit.py +++ b/tests/test_teyit.py @@ -7,7 +7,7 @@ TEST_DATA_DIR = Path(__file__).parent -class TeyitTestCase(unittest.TestCase): +class BaseCase: def assertRewrites(self, given, expected, case_count=None, **kwargs): with self.subTest(given=given, expected=expected, **kwargs): source, cases = teyit.refactor_until_deterministic(given) @@ -18,16 +18,18 @@ def assertRewrites(self, given, expected, case_count=None, **kwargs): def assertNotRewrites(self, given): self.assertRewrites(given, given, case_count=0) + +class TeyitTestCase(BaseCase, unittest.TestCase): def test_rewrite(self): func = ast.parse("self.assertTrue(x is None)", mode="eval").body - rewrite = teyit.Rewrite(func, "assertIsNone", [func.args[0].left]) + rewrite = teyit.FunctionRewrite(func, "assertIsNone", [func.args[0].left]) self.assertEqual( ast.unparse(rewrite.build_node()), "self.assertIsNone(x)" ) self.assertEqual(rewrite.get_arg_offset(), 0) func = ast.parse("self.assertTrue(a == b)", mode="eval").body - rewrite = teyit.Rewrite( + rewrite = teyit.FunctionRewrite( func, "assertEqual", [func.args[0].left, *func.args[0].comparators] ) self.assertEqual(rewrite.get_arg_offset(), 1) @@ -35,7 +37,7 @@ def test_rewrite(self): func = ast.parse( "self.assertTrue(x is None, message='XYZ')", mode="eval" ).body - rewrite = teyit.Rewrite(func, "assertIsNone", [func.args[0].left]) + rewrite = teyit.FunctionRewrite(func, "assertIsNone", [func.args[0].left]) self.assertEqual( ast.unparse(rewrite.build_node()), "self.assertIsNone(x, message='XYZ')", @@ -44,7 +46,7 @@ def test_rewrite(self): func = ast.parse( "self.assertIs(x, None, message='XYZ')", mode="eval" ).body - rewrite = teyit.Rewrite(func, "assertIsNone", [func.args[0]]) + rewrite = teyit.FunctionRewrite(func, "assertIsNone", [func.args[0]]) self.assertEqual(rewrite.get_arg_offset(), -1) def test_assert_rewriter_basic(self): @@ -211,6 +213,7 @@ def test_assert_rewriter_deprecated(self): ) def test_assert_rewriter_cosmetic(self): + self.maxDiff = None for case in (TEST_DATA_DIR / "cosmetic").iterdir(): self.assertRewrites( (case / "input.py").read_text(), @@ -219,5 +222,38 @@ def test_assert_rewriter_cosmetic(self): ) +_TEST_CASE_TEMPLATE = """ +class TestCase: + def test_my(self): + {0} +""" + + +class AssertStmtTestCase(BaseCase, unittest.TestCase): + def assertStmtRewrites(self, given, expected): + self.assertRewrites( + _TEST_CASE_TEMPLATE.format(given), + _TEST_CASE_TEMPLATE.format(expected), + ) + + def test_assert_stmt_to_call(self): + self.assertStmtRewrites( + "assert x == y", + "self.assertEqual(x, y)", + ) + self.assertStmtRewrites( + "assert x != y", + "self.assertNotEqual(x, y)", + ) + self.assertStmtRewrites( + "assert x is y", + "self.assertIs(x, y)", + ) + self.assertStmtRewrites( + "assert isinstance(x, y)", + "self.assertIsInstance(x, y)", + ) + + if __name__ == "__main__": unittest.main() diff --git a/teyit.py b/teyit.py index bfc9f59..7a4f2ce 100644 --- a/teyit.py +++ b/teyit.py @@ -46,8 +46,14 @@ } -@dataclass -class Rewrite: +@dataclass(frozen=True) +class FunctionRewrite: + """ + Used to rewrite test functions. + + Example: ``self.assertTrue(x == y)`` -> ``self.assertEqual(x, y)`` + """ + node: ast.Call func: str args: List[ast.AST] @@ -69,6 +75,34 @@ def get_arg_offset(self): return len(new_node.args + new_node.keywords) - prev_args +@dataclass(frozen=True) +class AssertStmtRewrite: + """ + Used to rewrite ``assert`` statement into function calls. + + Example: ``assert x == y`` -> ``self.assertEqual(x, y)`` + """ + + node: ast.Assert + expr: FunctionRewrite + + def __hash__(self): + return hash(id(self)) + + @lru_cache(maxsize=1) + def build_node(self): + new_node = self.expr.build_node() + new_node.lineno = self.node.lineno + new_node.col_offset = self.node.col_offset + new_node.end_lineno = self.node.end_lineno + new_node.end_col_offset = self.node.end_col_offset + return new_node + + @lru_cache(maxsize=1) + def get_arg_offset(self): + return 2 + + class _AssertRewriter(ast.NodeVisitor): def __init__(self, blacklist=frozenset(), *args, **kwargs): self.asserts = [] @@ -87,7 +121,7 @@ def visit_Call(self, node): with suppress(Exception): if node.func.attr in DEPRECATED_ALIASES: self.asserts.append( - Rewrite( + FunctionRewrite( node, DEPRECATED_ALIASES[node.func.attr], node.args ) ) @@ -133,7 +167,7 @@ def visit_assertTrue(self, node, positive=True): args = [*expr.args, *args] else: return None - return Rewrite(node, func, args) + return FunctionRewrite(node, func, args) def visit_assertFalse(self, node): return self.visit_assertTrue(node, positive=False) @@ -155,7 +189,7 @@ def visit_assertIs(self, node, positive=True): args = [left, *args] else: return None - return Rewrite(node, func, args) + return FunctionRewrite(node, func, args) def visit_assertIsNot(self, node): return self.visit_assertIs(node, positive=False) @@ -168,7 +202,63 @@ def visit_assertDictContainsSubset(self, node): ast.Dict(keys=[None, None], values=[right, left]), *args, ] - return Rewrite(node, func, args) + return FunctionRewrite(node, func, args) + + def visit_ClassDef(self, node): + for child_node in ast.walk(node): + if not isinstance( + child_node, (ast.FunctionDef, ast.AsyncFunctionDef) + ): + continue + + if not _looks_like_test(node.name) or not _looks_like_test( + child_node.name + ): + continue + + if not any(arg.arg == "self" for arg in child_node.args.args): + continue + + visitor = _AssertStmtFinder() + visitor.visit(child_node) + + for assert_stmt in visitor.stmts: + # with suppress(Exception): + new_expr = self.visit_assertTrue( + ast.Call( + ast.Attribute( + ast.Name("self", ast.Load), + "assertTrue", + ast.Load, + ), + [assert_stmt.test], + [], # TODO: support `msg` part + ) + ) + if new_expr is None: + continue + + self.asserts.append( + AssertStmtRewrite( + node=assert_stmt, + expr=new_expr, + ) + ) + self.generic_visit(node) + + +class _AssertStmtFinder(ast.NodeVisitor): + def __init__(self): + self.stmts = [] + + def visit_Assert(self, node): + # Is called with right context: inside an instance test method. + self.stmts.append(node) + self.generic_visit(node) + + +def _looks_like_test(name): + return "test" in name or "Test" in name class _FormattedUnparser(PreciseUnparser):