diff --git a/tests/test_async_code_match.py b/tests/test_async_code_match.py index 0ec4550d4..8d13aa04e 100644 --- a/tests/test_async_code_match.py +++ b/tests/test_async_code_match.py @@ -73,38 +73,50 @@ def _execute_code_match(self, source, asource): .replace("ASYNC_", "") ) asource = re.sub(" *?# type: ignore", "", asource) + if source != asource: + pass self.assertEqual(source, asource) def test_code_match_for_async_methods(self): for fpath in (pathlib.Path(__file__).parent.parent / "google").rglob("*.py"): + print(fpath) if fpath.name in EXEMPT_FILES or any([d in fpath.parts for d in EXEMPT_DIRS]): continue # print(f"Checking {fpath.absolute()}") - code_match_funcs: dict[str, ast.AST] = {} + funcs: dict[str, ast.AST] = {} + async_funcs: dict[str, ast.AST] = {} source = fpath.read_text() source_nodes = ast.parse(source) for node in ast.walk(source_nodes): - if isinstance( - node, (ast.FunctionDef, ast.AsyncFunctionDef) - ) and not node.name.startswith("__"): - name = node.name[:-6] if node.name.endswith("_async") else node.name - if name in EXEMPT_FUNCTIONS or self._inspect_decorator_exemption(node, fpath): - continue - # print(f"Checking {node.name}") - - if func_name := code_match_funcs.pop(name, None): - snode, anode = ( - (func_name, node) - if isinstance(node, ast.AsyncFunctionDef) - else (node, func_name) - ) - func_source = self._maybe_trim_docstring(snode) - func_asource = self._maybe_trim_docstring(anode) - self._execute_code_match(func_source, func_asource) - # print(f"Matched {node.name}") + + + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if node.name.startswith("__"): + continue + + if isinstance(node, ast.AsyncFunctionDef): + is_async = True + code_match_funcs = funcs + else: + is_async = False + code_match_funcs = async_funcs + + name = node.name.removesuffix("_async") + if name in EXEMPT_FUNCTIONS or self._inspect_decorator_exemption(node, fpath): + continue + + if matched_node := code_match_funcs.pop(name, None): + if is_async: + snode, anode = (matched_node, node) else: - code_match_funcs[node.name] = node + anode, snode = (matched_node, node) + func_source = self._maybe_trim_docstring(snode) + func_asource = self._maybe_trim_docstring(anode) + self._execute_code_match(func_source, func_asource) + else: + code_match_funcs[node.name] = node if __name__ == "__main__":