From 3ffe85941df66ca7a01903517cda198b1f62ef15 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 5 Feb 2025 08:56:46 +0100 Subject: [PATCH] Remove non-used functionality from py2fgen (#656) - Remove special handling for wrapping gtx.Program (if needed at some point, the cleaner way would be to first wrap the program and then use the standard py2fgen) - Remove type annotations from the wrapper (they were wrong and required extra logic to extract them) - Remove copying over the imports from the wrapped module to the wrapper, unclear why that was needed in the past - Enable mypy on `icon4py.tools` and add type annotations - Move `unpack` and `int_array_to_bool_array` to `wrapper_util` module and import from there (they are now required at runtime). --- .pre-commit-config.yaml | 22 +-- pyproject.toml | 2 +- tools/src/icon4py/tools/common/metadata.py | 4 +- tools/src/icon4py/tools/icon4pygen/cli.py | 4 +- tools/src/icon4py/tools/liskov/cli.py | 4 +- .../liskov/codegen/integration/template.py | 2 +- .../tools/liskov/parsing/validation.py | 2 +- tools/src/icon4py/tools/py2fgen/cli.py | 2 +- tools/src/icon4py/tools/py2fgen/generate.py | 1 - tools/src/icon4py/tools/py2fgen/parsing.py | 131 ++---------------- tools/src/icon4py/tools/py2fgen/plugin.py | 106 -------------- tools/src/icon4py/tools/py2fgen/settings.py | 45 ++++-- tools/src/icon4py/tools/py2fgen/template.py | 80 +++-------- tools/src/icon4py/tools/py2fgen/utils.py | 6 +- .../icon4py/tools/py2fgen/wrapper_utils.py | 121 ++++++++++++++++ .../tools/py2fgen/wrappers/debug_utils.py | 2 +- tools/tests/py2fgen/test_cffi.py | 5 +- tools/tests/py2fgen/test_cli.py | 13 +- tools/tests/py2fgen/test_codegen.py | 118 +++------------- tools/tests/py2fgen/test_parsing.py | 29 +--- 20 files changed, 246 insertions(+), 453 deletions(-) create mode 100644 tools/src/icon4py/tools/py2fgen/wrapper_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aa9dd0053..db27937d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,14 +68,14 @@ repos: - id: tach name: Check inter-package dependencies -# TODO(egparedes): fix and activate mypy hook -# - repo: local -# hooks: -# - id: mypy -# name: mypy static type checker -# entry: bash -c 'mypy tools/src/icon4py/tools model/common/src/icon4py/model/common' -# language: system -# types_or: [python, pyi] -# pass_filenames: false -# require_serial: true -# stages: [pre-commit] + +- repo: local + hooks: + - id: mypy + name: mypy static type checker + entry: bash -c 'mypy tools/src/icon4py/tools' # model/common/src/icon4py/model/common # TODO(egparedes): fix and activate mypy hook for all packages + language: system + types_or: [python, pyi] + pass_filenames: false + require_serial: true + stages: [pre-commit] diff --git a/pyproject.toml b/pyproject.toml index 5006f5ff5..2235b3547 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -171,7 +171,7 @@ disallow_untyped_defs = true # '^tests/liskov/*.py', # '^tests/py2f/*.py', # ] -ignore_missing_imports = true +ignore_missing_imports = false implicit_reexport = true install_types = true non_interactive = true diff --git a/tools/src/icon4py/tools/common/metadata.py b/tools/src/icon4py/tools/common/metadata.py index 086357421..0a18eb164 100644 --- a/tools/src/icon4py/tools/common/metadata.py +++ b/tools/src/icon4py/tools/common/metadata.py @@ -58,7 +58,7 @@ class DummyConnectivity(Connectivity): neighbor_axis: Dimension = Dimension("unused") index_type: type[int] = int - def mapped_index(self, cur_index, neigh_index) -> int: + def mapped_index(self, cur_index, neigh_index) -> int: # type: ignore[no-untyped-def] # code will disappear with next gt4py version raise AssertionError("Unreachable") return 0 @@ -89,7 +89,7 @@ def _get_field_infos(fvprog: Program) -> dict[str, FieldInfo]: assert all( _is_list_of_names(body.args) for body in fvprog.past_stage.past_node.body ), "Found unsupported expression in input arguments." - input_arg_ids = set(arg.id for body in fvprog.past_stage.past_node.body for arg in body.args) # type: ignore[attr-defined] # Checked in the assert + input_arg_ids = set(arg.id for body in fvprog.past_stage.past_node.body for arg in body.args) out_args = (body.kwargs["out"] for body in fvprog.past_stage.past_node.body) output_fields = [] diff --git a/tools/src/icon4py/tools/icon4pygen/cli.py b/tools/src/icon4py/tools/icon4pygen/cli.py index 1121f4f30..b3adcc100 100644 --- a/tools/src/icon4py/tools/icon4pygen/cli.py +++ b/tools/src/icon4py/tools/icon4pygen/cli.py @@ -25,7 +25,9 @@ class ModuleType(click.ParamType): f"{dycore_import_path}.mo_velocity_advection_stencil_", ] - def shell_complete(self, ctx, param, incomplete): + def shell_complete( + self, ctx: click.Context, param: click.Parameter, incomplete: str + ) -> list[click.shell_completion.CompletionItem]: if len(incomplete) > 0 and incomplete.endswith(":"): completions = [incomplete + incomplete[:-1].split(".")[-1]] else: diff --git a/tools/src/icon4py/tools/liskov/cli.py b/tools/src/icon4py/tools/liskov/cli.py index 11e360f19..0eea0c303 100644 --- a/tools/src/icon4py/tools/liskov/cli.py +++ b/tools/src/icon4py/tools/liskov/cli.py @@ -23,7 +23,9 @@ logger = setup_logger(__name__) -def split_comma(ctx, param, value) -> Optional[tuple[str]]: +def split_comma( + ctx: click.Context, param: click.Parameter, value: str +) -> Optional[tuple[str, ...]]: return tuple(v.strip() for v in value.split(",")) if value else None diff --git a/tools/src/icon4py/tools/liskov/codegen/integration/template.py b/tools/src/icon4py/tools/liskov/codegen/integration/template.py index 4cd1ab668..a19fa450a 100644 --- a/tools/src/icon4py/tools/liskov/codegen/integration/template.py +++ b/tools/src/icon4py/tools/liskov/codegen/integration/template.py @@ -175,7 +175,7 @@ class EndStencilStatement(EndBasicStencilStatement): noprofile: Optional[bool] noaccenddata: Optional[bool] - def __post_init__(self, *args, **kwargs) -> None: + def __post_init__(self, *args: Any, **kwargs: Any) -> None: all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] self.bounds_fields = BoundsFields(**asdict(self.stencil_data.bounds)) self.name = self.stencil_data.name diff --git a/tools/src/icon4py/tools/liskov/parsing/validation.py b/tools/src/icon4py/tools/liskov/parsing/validation.py index f714e28f5..5a80c3d03 100644 --- a/tools/src/icon4py/tools/liskov/parsing/validation.py +++ b/tools/src/icon4py/tools/liskov/parsing/validation.py @@ -158,7 +158,7 @@ def _validate_stencil_directives(self, directives: Sequence[ts.ParsedDirective]) def _identify_unbalanced_directives( directives: Sequence[ts.ParsedDirective], directive_types: tuple[Type[ts.ParsedDirective], ...], - ): + ) -> None: directive_counts: dict[str, int] = defaultdict(int) for directive in directives: if isinstance(directive, directive_types): diff --git a/tools/src/icon4py/tools/py2fgen/cli.py b/tools/src/icon4py/tools/py2fgen/cli.py index 20c101f0b..0485927e4 100644 --- a/tools/src/icon4py/tools/py2fgen/cli.py +++ b/tools/src/icon4py/tools/py2fgen/cli.py @@ -21,7 +21,7 @@ from icon4py.tools.py2fgen.settings import GT4PyBackend -def parse_comma_separated_list(ctx, param, value) -> list[str]: +def parse_comma_separated_list(ctx: click.Context, param: click.Parameter, value: str) -> list[str]: # Splits the input string by commas and strips any leading/trailing whitespace from the strings return [item.strip() for item in value.split(",")] diff --git a/tools/src/icon4py/tools/py2fgen/generate.py b/tools/src/icon4py/tools/py2fgen/generate.py index 8af6e4aa2..eb74deb24 100644 --- a/tools/src/icon4py/tools/py2fgen/generate.py +++ b/tools/src/icon4py/tools/py2fgen/generate.py @@ -61,7 +61,6 @@ def generate_python_wrapper( module_name=plugin.module_name, plugin_name=plugin.plugin_name, functions=plugin.functions, - imports=plugin.imports, backend=backend, debug_mode=debug_mode, limited_area=limited_area, diff --git a/tools/src/icon4py/tools/py2fgen/parsing.py b/tools/src/icon4py/tools/py2fgen/parsing.py index 6100e53ca..a40afa933 100644 --- a/tools/src/icon4py/tools/py2fgen/parsing.py +++ b/tools/src/icon4py/tools/py2fgen/parsing.py @@ -6,147 +6,40 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import ast import importlib -import inspect -import re from inspect import signature, unwrap from types import ModuleType from typing import Callable, List -from gt4py.next import Dimension -from gt4py.next.ffront.decorator import Program -from gt4py.next.type_system.type_translation import from_type_hint +from gt4py.next.type_system import type_translation as gtx_type_translation from icon4py.tools.py2fgen.template import CffiPlugin, Func, FuncParameter from icon4py.tools.py2fgen.utils import parse_type_spec -class ImportStmtVisitor(ast.NodeVisitor): - """AST Visitor to extract import statements.""" - - def __init__(self): - self.import_statements: list[str] = [] - - def visit_Import(self, node): - for alias in node.names: - import_statement = f"import {alias.name}" + ( - f" as {alias.asname}" if alias.asname else "" - ) - self.import_statements.append(import_statement) - - def visit_ImportFrom(self, node): - for alias in node.names: - import_statement = f"from {node.module} import {alias.name}" + ( - f" as {alias.asname}" if alias.asname else "" - ) - self.import_statements.append(import_statement) - - -class TypeHintVisitor(ast.NodeVisitor): - """AST Visitor to extract function parameter type hints.""" - - def __init__(self): - self.type_hints: dict[str, str] = {} - - def visit_FunctionDef(self, node): - for arg in node.args.args: - if arg.annotation: - annotation = ast.unparse(arg.annotation) - self.type_hints[arg.arg] = annotation - else: - raise TypeError( - f"Missing type hint for parameter '{arg.arg}' in function '{node.name}'" - ) - - def parse(module_name: str, functions: list[str], plugin_name: str) -> CffiPlugin: module = importlib.import_module(module_name) - parsed_imports = _extract_import_statements(module) - - parsed_functions: list[Func] = [] - for f in functions: - parsed_functions.append(_parse_function(module, f)) + parsed_functions = [_parse_function(module, f) for f in functions] return CffiPlugin( module_name=module_name, plugin_name=plugin_name, functions=parsed_functions, - imports=parsed_imports, ) -def _extract_import_statements(module: ModuleType) -> list[str]: - src = inspect.getsource(module) - tree = ast.parse(src) - visitor = ImportStmtVisitor() - visitor.visit(tree) - return visitor.import_statements - - def _parse_function(module: ModuleType, function_name: str) -> Func: func = unwrap(getattr(module, function_name)) - is_gt4py_program = isinstance(func, Program) - type_hints = _extract_type_hint_strings(module, func, is_gt4py_program, function_name) + params = _parse_params(func) + return Func(name=function_name, args=params) - params = ( - _get_gt4py_func_params(func, type_hints) - if is_gt4py_program - else _get_simple_func_params(func, type_hints) - ) - return Func(name=function_name, args=params, is_gt4py_program=is_gt4py_program) - - -def _extract_type_hint_strings( - module: ModuleType, func: Callable, is_gt4py_program: bool, function_name: str -): - src = extract_function_signature( - inspect.getsource(module) if is_gt4py_program else inspect.getsource(func), function_name - ) - tree = ast.parse(src) - visitor = TypeHintVisitor() - visitor.visit(tree) - return visitor.type_hints - - -def extract_function_signature(code: str, function_name: str) -> str: - # This pattern attempts to match function definitions - pattern = rf"\bdef\s+{re.escape(function_name)}\s*\(([\s\S]*?)\)\s*:" - - match = re.search(pattern, code) - - if match: - # Constructing the full signature with empty return for ease of parsing by AST visitor - signature = match.group() - return signature.strip() + "\n return None" - else: - raise Exception(f"Could not parse function signature from the following code:\n {code}") - - -def _get_gt4py_func_params(func: Program, type_hints: dict[str, str]) -> List[FuncParameter]: - return [ - FuncParameter( - name=p.id, - d_type=parse_type_spec(p.type)[1], - dimensions=parse_type_spec(p.type)[0], - py_type_hint=type_hints[p.id], - ) - for p in func.past_stage.past_node.params - ] - - -def _get_simple_func_params(func: Callable, type_hints: dict[str, str]) -> List[FuncParameter]: +def _parse_params(func: Callable) -> List[FuncParameter]: sig_params = signature(func, follow_wrapped=False).parameters - return [ - FuncParameter( - name=s, - d_type=parse_type_spec(from_type_hint(param.annotation))[1], - dimensions=[ - Dimension(value=d.value) - for d in parse_type_spec(from_type_hint(param.annotation))[0] - ], - py_type_hint=type_hints.get(s, None), - ) - for s, param in sig_params.items() - ] + params = [] + for s, param in sig_params.items(): + gt4py_type = gtx_type_translation.from_type_hint(param.annotation) + dims, dtype = parse_type_spec(gt4py_type) + params.append(FuncParameter(name=s, d_type=dtype, dimensions=dims)) + + return params diff --git a/tools/src/icon4py/tools/py2fgen/plugin.py b/tools/src/icon4py/tools/py2fgen/plugin.py index 02256903b..8b30b8c6b 100644 --- a/tools/src/icon4py/tools/py2fgen/plugin.py +++ b/tools/src/icon4py/tools/py2fgen/plugin.py @@ -7,122 +7,16 @@ # SPDX-License-Identifier: BSD-3-Clause import logging -import math -import typing from pathlib import Path import cffi -import numpy as np -from cffi import FFI -from numpy.typing import NDArray from icon4py.tools.common.logger import setup_logger -if typing.TYPE_CHECKING: - import cupy as cp # type: ignore - -ffi = FFI() # needed for unpack and unpack_gpu functions - logger = setup_logger(__name__) -def unpack(ptr, *sizes: int) -> NDArray: - """ - Converts a C pointer into a NumPy array to directly manipulate memory allocated in Fortran. - This function is needed for operations requiring in-place modification of CPU data, enabling - changes made in Python to reflect immediately in the original Fortran memory space. - - Args: - ptr (CData): A CFFI pointer to the beginning of the data array in CPU memory. This pointer - should reference a contiguous block of memory whose total size matches the product - of the specified dimensions. - *sizes (int): Variable length argument list specifying the dimensions of the array. - These sizes determine the shape of the resulting NumPy array. - - Returns: - np.ndarray: A NumPy array that provides a direct view of the data pointed to by `ptr`. - This array shares the underlying data with the original Fortran code, allowing - modifications made through the array to affect the original data. - """ - length = math.prod(sizes) - c_type = ffi.getctype(ffi.typeof(ptr).item) - - # Map C data types to NumPy dtypes - dtype_map: dict[str, np.dtype] = { - "int": np.dtype(np.int32), - "double": np.dtype(np.float64), - } - dtype = dtype_map.get(c_type, np.dtype(c_type)) - - # Create a NumPy array from the buffer, specifying the Fortran order - arr = np.frombuffer(ffi.buffer(ptr, length * ffi.sizeof(c_type)), dtype=dtype).reshape( # type: ignore - sizes, order="F" - ) - return arr - - -def unpack_gpu(ptr, *sizes: int): - """ - Converts a C pointer into a CuPy array to directly manipulate memory allocated in Fortran. - This function is needed for operations that require in-place modification of GPU data, - enabling changes made in Python to reflect immediately in the original Fortran memory space. - - Args: - ptr (cffi.CData): A CFFI pointer to GPU memory allocated by OpenACC, representing - the starting address of the data. This pointer must correspond to - a contiguous block of memory whose total size matches the product - of the specified dimensions. - *sizes (int): Variable length argument list specifying the dimensions of the array. - These sizes determine the shape of the resulting CuPy array. - - Returns: - cp.ndarray: A CuPy array that provides a direct view of the data pointed to by `ptr`. - This array shares the underlying data with the original Fortran code, allowing - modifications made through the array to affect the original data. - """ - - if not sizes: - raise ValueError("Sizes must be provided to determine the array shape.") - - length = math.prod(sizes) - c_type = ffi.getctype(ffi.typeof(ptr).item) - - dtype_map = { - "int": cp.int32, - "double": cp.float64, - } - dtype = dtype_map.get(c_type, None) - if dtype is None: - raise ValueError(f"Unsupported C data type: {c_type}") - - itemsize = ffi.sizeof(c_type) - total_size = length * itemsize - - # cupy array from OpenACC device pointer - current_device = cp.cuda.Device() - ptr_val = int(ffi.cast("uintptr_t", ptr)) - mem = cp.cuda.UnownedMemory(ptr_val, total_size, owner=ptr, device_id=current_device.id) - memptr = cp.cuda.MemoryPointer(mem, 0) - arr = cp.ndarray(shape=sizes, dtype=dtype, memptr=memptr, order="F") - return arr - - -def int_array_to_bool_array(int_array: NDArray) -> NDArray: - """ - Converts a NumPy array of integers to a boolean array. - In the input array, 0 represents False, and any non-zero value (1 or -1) represents True. - - Args: - int_array: A NumPy array of integers. - - Returns: - A NumPy array of booleans. - """ - bool_array = int_array != 0 - return bool_array - - def generate_and_compile_cffi_plugin( plugin_name: str, c_header: str, python_wrapper: str, build_path: Path, backend: str ) -> None: diff --git a/tools/src/icon4py/tools/py2fgen/settings.py b/tools/src/icon4py/tools/py2fgen/settings.py index 09f4902b9..05c9d146f 100644 --- a/tools/src/icon4py/tools/py2fgen/settings.py +++ b/tools/src/icon4py/tools/py2fgen/settings.py @@ -9,9 +9,10 @@ import os from enum import Enum from functools import cached_property +from types import ModuleType import numpy as np -from gt4py.next import itir_python as run_roundtrip +from gt4py.next import backend as gtx_backend, itir_python as run_roundtrip from gt4py.next.program_processors.runners.gtfn import ( run_gtfn_cached, run_gtfn_gpu_cached, @@ -19,7 +20,7 @@ try: - import dace # type: ignore[import-not-found, import-untyped] + import dace # type: ignore[import-untyped] from gt4py.next.program_processors.runners.dace import ( run_dace_cpu, run_dace_cpu_noopt, @@ -33,6 +34,24 @@ dace: Optional[ModuleType] = None # type: ignore[no-redef] # definition needed here +def env_flag_to_bool(name: str, default: bool) -> bool: # copied from gt4py.next.config + """Recognize true or false signaling string values.""" + flag_value = None + if name in os.environ: + flag_value = os.environ[name].lower() + match flag_value: + case None: + return default + case "0" | "false" | "off": + return False + case "1" | "true" | "on": + return True + case _: + raise ValueError( + "Invalid environment flag value: use '0 | false | off' or '1 | true | on'." + ) + + class Device(Enum): CPU = "CPU" GPU = "GPU" @@ -51,7 +70,7 @@ class GT4PyBackend(Enum): @dataclasses.dataclass class Icon4PyConfig: @cached_property - def icon4py_backend(self): + def icon4py_backend(self) -> str: backend = os.environ.get("ICON4PY_BACKEND", "CPU") if hasattr(GT4PyBackend, backend): return backend @@ -62,12 +81,12 @@ def icon4py_backend(self): ) @cached_property - def icon4py_dace_orchestration(self): + def icon4py_dace_orchestration(self) -> bool: # Any value other than None will be considered as True - return os.environ.get("ICON4PY_DACE_ORCHESTRATION", None) + return env_flag_to_bool("ICON4PY_DACE_ORCHESTRATION", False) @cached_property - def array_ns(self): + def array_ns(self) -> ModuleType: if self.device == Device.GPU: import cupy as cp # type: ignore[import-not-found] @@ -76,8 +95,8 @@ def array_ns(self): return np @cached_property - def gt4py_runner(self): - backend_map = { + def gt4py_runner(self) -> gtx_backend.Backend: + backend_map: dict[str, gtx_backend.Backend] = { GT4PyBackend.CPU.name: run_gtfn_cached, GT4PyBackend.GPU.name: run_gtfn_gpu_cached, GT4PyBackend.ROUNDTRIP.name: run_roundtrip, @@ -92,7 +111,7 @@ def gt4py_runner(self): return backend_map[self.icon4py_backend] @cached_property - def device(self): + def device(self) -> Device: device_map = { GT4PyBackend.CPU.name: Device.CPU, GT4PyBackend.GPU.name: Device.GPU, @@ -109,12 +128,12 @@ def device(self): return device @cached_property - def limited_area(self): - return os.environ.get("ICON4PY_LAM", False) + def limited_area(self) -> bool: + return env_flag_to_bool("ICON4PY_LAM", False) @cached_property - def parallel_run(self): - return os.environ.get("ICON4PY_PARALLEL", False) + def parallel_run(self) -> bool: + return env_flag_to_bool("ICON4PY_PARALLEL", False) config = Icon4PyConfig() diff --git a/tools/src/icon4py/tools/py2fgen/template.py b/tools/src/icon4py/tools/py2fgen/template.py index b61e88b32..78c41bd49 100644 --- a/tools/src/icon4py/tools/py2fgen/template.py +++ b/tools/src/icon4py/tools/py2fgen/template.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import inspect +from types import ModuleType from typing import Any, Sequence from gt4py.eve import Node, datamodels @@ -19,7 +19,6 @@ BUILTIN_TO_ISO_C_TYPE, BUILTIN_TO_NUMPY_TYPE, ) -from icon4py.tools.py2fgen.plugin import int_array_to_bool_array, unpack, unpack_gpu from icon4py.tools.py2fgen.settings import GT4PyBackend from icon4py.tools.py2fgen.utils import flatten_and_get_unique_elts from icon4py.tools.py2fgen.wrappers import wrapper_dimension @@ -30,11 +29,11 @@ UNINITIALISED_ARRAYS = [ - "mask_hdiff", + "mask_hdiff", # optional diffusion init fields "zd_diffcoef", "zd_vertoffset", "zd_intcoef", - "hdef_ic", + "hdef_ic", # optional diffusion output fields "div_ic", "dwdx", "dwdy", @@ -53,14 +52,13 @@ class FuncParameter(Node): name: str d_type: ScalarKind dimensions: Sequence[Dimension] - py_type_hint: str size_args: list[str] = datamodels.field(init=False) is_array: bool = datamodels.field(init=False) gtdims: list[str] = datamodels.field(init=False) size_args_len: int = datamodels.field(init=False) np_type: str = datamodels.field(init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.size_args = dims_to_size_strings(self.dimensions) self.size_args_len = len(self.size_args) self.is_array = True if len(self.dimensions) >= 1 else False @@ -80,7 +78,6 @@ def __post_init__(self): class Func(Node): name: str args: Sequence[FuncParameter] - is_gt4py_program: bool global_size_args: Sequence[str] = datamodels.field(init=False) def __post_init__(self, *args: Any, **kwargs: Any) -> None: @@ -92,7 +89,6 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None: class CffiPlugin(Node): module_name: str plugin_name: str - imports: list[str] functions: list[Func] @@ -102,19 +98,14 @@ class PythonWrapper(CffiPlugin): profile: bool limited_area: bool cffi_decorator: str = CFFI_DECORATOR - cffi_unpack: str = inspect.getsource(unpack) - cffi_unpack_gpu: str = inspect.getsource(unpack_gpu) - int_to_bool: str = inspect.getsource(int_array_to_bool_array) gt4py_backend: str = datamodels.field(init=False) - is_gt4py_program_present: bool = datamodels.field(init=False) def __post_init__(self, *args: Any, **kwargs: Any) -> None: self.gt4py_backend = GT4PyBackend[self.backend].value - self.is_gt4py_program_present = any(func.is_gt4py_program for func in self.functions) self.uninitialised_arrays = get_uninitialised_arrays(self.limited_area) -def get_uninitialised_arrays(limited_area: bool): +def get_uninitialised_arrays(limited_area: bool) -> list[str]: return UNINITIALISED_ARRAYS if not limited_area else [] @@ -124,7 +115,7 @@ def build_array_size_args() -> dict[str, str]: from icon4py.tools.py2fgen.wrappers import wrapper_dimension # Function to process the dimensions - def process_dimensions(module): + def process_dimensions(module: ModuleType) -> None: for var_name, var in vars(module).items(): if isinstance(var, Dimension): dim_name = var_name.replace( @@ -233,31 +224,19 @@ def render_fortran_array_sizes(param: FuncParameter) -> str: class PythonWrapperGenerator(TemplatedGenerator): + # TODO(havogt): put np_as_located_field logic into unpack PythonWrapper = as_jinja( """\ # imports for generated wrapper code import logging {% if _this_node.profile %}import time{% endif %} -import math from {{ plugin_name }} import ffi -import numpy as np {% if _this_node.backend == 'GPU' %}import cupy as cp {% endif %} -from numpy.typing import NDArray from gt4py.next.iterator.embedded import np_as_located_field from icon4py.tools.py2fgen.settings import config -xp = config.array_ns +from icon4py.tools.py2fgen import wrapper_utils from icon4py.model.common import dimension as dims -{% if _this_node.is_gt4py_program_present %} -# necessary imports when embedding a gt4py program directly -from gt4py.next import itir_python as run_roundtrip -from gt4py.next.program_processors.runners.gtfn import run_gtfn_cached, run_gtfn_gpu_cached -from icon4py.model.common.grid.simple import SimpleGrid - -# We need a grid to pass offset providers to the embedded gt4py program (granules load their own grid at runtime) -grid = SimpleGrid() -{% endif %} - # logger setup log_format = '%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s' logging.basicConfig(level=logging.{%- if _this_node.debug_mode -%}DEBUG{%- else -%}ERROR{%- endif -%}, @@ -265,35 +244,20 @@ class PythonWrapperGenerator(TemplatedGenerator): datefmt='%Y-%m-%d %H:%M:%S') {% if _this_node.backend == 'GPU' %}logging.info(cp.show_config()) {% endif %} -import numpy as np - -# embedded module imports -{% for stmt in imports -%} -{{ stmt }} -{% endfor %} - # embedded function imports {% for func in _this_node.functions -%} from {{ module_name }} import {{ func.name }} {% endfor %} -{% if _this_node.backend == 'GPU' %} -{{ cffi_unpack_gpu }} -{% else %} -{{ cffi_unpack }} -{% endif %} - -{{ int_to_bool }} - {% for func in _this_node.functions %} {{ cffi_decorator }} def {{ func.name }}_wrapper( {%- for arg in func.args -%} -{{ arg.name }}: {{ arg.py_type_hint | replace("KHalfDim","KDim") }}{% if not loop.last or func.global_size_args %}, {% endif %} +{{ arg.name }}{% if not loop.last or func.global_size_args %}, {% endif %} {%- endfor %} {%- for arg in func.global_size_args -%} -{{ arg }}: gtx.int32{{ ", " if not loop.last else "" }} +{{ arg }}{{ ", " if not loop.last else "" }} {%- endfor -%} ): try: @@ -317,11 +281,11 @@ def {{ func.name }}_wrapper( {%- if arg.name in _this_node.uninitialised_arrays -%} {{ arg.name }} = xp.ones((1,) * {{ arg.size_args_len }}, dtype={{arg.np_type}}, order="F") {%- else -%} - {{ arg.name }} = unpack{%- if _this_node.backend == 'GPU' -%}_gpu{%- endif -%}({{ arg.name }}, {{ ", ".join(arg.size_args) }}) + {{ arg.name }} = wrapper_utils.unpack{%- if _this_node.backend == 'GPU' -%}_gpu{%- endif -%}(ffi, {{ arg.name }}, {{ ", ".join(arg.size_args) }}) {%- endif -%} {%- if arg.d_type.name == "BOOL" %} - {{ arg.name }} = int_array_to_bool_array({{ arg.name }}) + {{ arg.name }} = wrapper_utils.int_array_to_bool_array({{ arg.name }}) {%- endif %} {%- if _this_node.debug_mode %} @@ -368,14 +332,10 @@ def {{ func.name }}_wrapper( func_start_time = time.perf_counter() {% endif %} - {{ func.name }} - {%- if func.is_gt4py_program -%}.with_backend({{ _this_node.gt4py_backend }}){%- endif -%}( + {{ func.name }}( {%- for arg in func.args -%} - {{ arg.name }}{{ ", " if not loop.last or func.is_gt4py_program else "" }} + {{ arg.name }}{{ ", " if not loop.last else "" }} {%- endfor -%} - {%- if func.is_gt4py_program -%} - offset_provider=grid.offset_providers - {%- endif -%} ) {% if _this_node.profile %} @@ -421,7 +381,7 @@ class CHeaderGenerator(TemplatedGenerator): "extern int {{ name }}_wrapper({%- for arg in args -%}{{ arg }}{% if not loop.last or global_size_args|length > 0 %}, {% endif %}{% endfor -%}{%- for sarg in global_size_args -%} int {{ sarg }}{% if not loop.last %}, {% endif %}{% endfor -%});" ) - def visit_FuncParameter(self, param: FuncParameter): + def visit_FuncParameter(self, param: FuncParameter) -> str: return self.generic_visit( param, rendered_type=to_c_type(param.d_type), pointer=render_c_pointer(param) ) @@ -474,14 +434,12 @@ class F90Interface(Node): def __post_init__(self, *args: Any, **kwargs: Any) -> None: functions = self.cffi_plugin.functions self.function_declaration = [ - F90FunctionDeclaration(name=f.name, args=f.args, is_gt4py_program=f.is_gt4py_program) - for f in functions + F90FunctionDeclaration(name=f.name, args=f.args) for f in functions ] self.function_definition = [ F90FunctionDefinition( name=f.name, args=f.args, - is_gt4py_program=f.is_gt4py_program, limited_area=self.limited_area, ) for f in functions @@ -509,7 +467,7 @@ class F90InterfaceGenerator(TemplatedGenerator): """ ) - def visit_F90FunctionDeclaration(self, func: F90FunctionDeclaration, **kwargs): + def visit_F90FunctionDeclaration(self, func: F90FunctionDeclaration, **kwargs: Any) -> str: arg_names = ", &\n ".join(map(lambda x: x.name, func.args)) if func.global_size_args: arg_names += ",&\n" + ", &\n".join(func.global_size_args) @@ -530,7 +488,7 @@ def visit_F90FunctionDeclaration(self, func: F90FunctionDeclaration, **kwargs): """ ) - def visit_F90FunctionDefinition(self, func: F90FunctionDefinition, **kwargs): + def visit_F90FunctionDefinition(self, func: F90FunctionDefinition, **kwargs: Any) -> str: if len(func.args) < 1: arg_names, param_names_with_size_args = "", "" else: @@ -583,7 +541,7 @@ def visit_F90FunctionDefinition(self, func: F90FunctionDefinition, **kwargs): """ ) - def visit_FuncParameter(self, param: FuncParameter, **kwargs): + def visit_FuncParameter(self, param: FuncParameter, **kwargs: Any) -> str: return self.generic_visit( param, value=as_f90_value(param), diff --git a/tools/src/icon4py/tools/py2fgen/utils.py b/tools/src/icon4py/tools/py2fgen/utils.py index 05fbb9a43..0d8e80e84 100644 --- a/tools/src/icon4py/tools/py2fgen/utils.py +++ b/tools/src/icon4py/tools/py2fgen/utils.py @@ -27,7 +27,7 @@ def flatten_and_get_unique_elts(list_of_lists: list[list[str]]) -> list[str]: return sorted(set(item for sublist in list_of_lists for item in sublist)) -def get_local_test_grid(grid_folder: str): +def get_local_test_grid(grid_folder: str) -> str: test_folder = "testdata" module_spec = importlib.util.find_spec("icon4py.tools") @@ -43,7 +43,7 @@ def get_local_test_grid(grid_folder: str): ) -def get_icon_grid_loc(): +def get_icon_grid_loc() -> str: env_path = os.environ.get("ICON_GRID_LOC") if env_path is not None: return env_path @@ -53,7 +53,7 @@ def get_icon_grid_loc(): ) -def get_grid_filename(): +def get_grid_filename() -> str: env_path = os.environ.get("ICON_GRID_NAME") if env_path is not None: return env_path diff --git a/tools/src/icon4py/tools/py2fgen/wrapper_utils.py b/tools/src/icon4py/tools/py2fgen/wrapper_utils.py new file mode 100644 index 000000000..af7c47a9b --- /dev/null +++ b/tools/src/icon4py/tools/py2fgen/wrapper_utils.py @@ -0,0 +1,121 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np + + +if TYPE_CHECKING: + import cffi + +try: + import cupy as cp # type: ignore +except ImportError: + cp = None + + +def unpack(ffi: cffi.FFI, ptr, *sizes: int) -> np.typing.NDArray: # type: ignore[no-untyped-def] # CData type not public? + """ + Converts a C pointer into a NumPy array to directly manipulate memory allocated in Fortran. + This function is needed for operations requiring in-place modification of CPU data, enabling + changes made in Python to reflect immediately in the original Fortran memory space. + + Args: + ffi (cffi.FFI): A CFFI FFI instance. + ptr (CData): A CFFI pointer to the beginning of the data array in CPU memory. This pointer + should reference a contiguous block of memory whose total size matches the product + of the specified dimensions. + *sizes (int): Variable length argument list specifying the dimensions of the array. + These sizes determine the shape of the resulting NumPy array. + + Returns: + np.ndarray: A NumPy array that provides a direct view of the data pointed to by `ptr`. + This array shares the underlying data with the original Fortran code, allowing + modifications made through the array to affect the original data. + """ + length = math.prod(sizes) + c_type = ffi.getctype(ffi.typeof(ptr).item) + + # Map C data types to NumPy dtypes + dtype_map: dict[str, np.dtype] = { + "int": np.dtype(np.int32), + "double": np.dtype(np.float64), + } + dtype = dtype_map.get(c_type, np.dtype(c_type)) + + # Create a NumPy array from the buffer, specifying the Fortran order + arr = np.frombuffer(ffi.buffer(ptr, length * ffi.sizeof(c_type)), dtype=dtype).reshape( # type: ignore + sizes, order="F" + ) + return arr + + +def unpack_gpu(ffi: cffi.FFI, ptr, *sizes: int): # type: ignore[no-untyped-def] # CData type not public? + """ + Converts a C pointer into a CuPy array to directly manipulate memory allocated in Fortran. + This function is needed for operations that require in-place modification of GPU data, + enabling changes made in Python to reflect immediately in the original Fortran memory space. + + Args: + ffi (cffi.FFI): A CFFI FFI instance. + ptr (cffi.CData): A CFFI pointer to GPU memory allocated by OpenACC, representing + the starting address of the data. This pointer must correspond to + a contiguous block of memory whose total size matches the product + of the specified dimensions. + *sizes (int): Variable length argument list specifying the dimensions of the array. + These sizes determine the shape of the resulting CuPy array. + + Returns: + cp.ndarray: A CuPy array that provides a direct view of the data pointed to by `ptr`. + This array shares the underlying data with the original Fortran code, allowing + modifications made through the array to affect the original data. + """ + + if not sizes: + raise ValueError("Sizes must be provided to determine the array shape.") + + length = math.prod(sizes) + c_type = ffi.getctype(ffi.typeof(ptr).item) + + dtype_map = { + "int": cp.int32, + "double": cp.float64, + } + dtype = dtype_map.get(c_type, None) + if dtype is None: + raise ValueError(f"Unsupported C data type: {c_type}") + + itemsize = ffi.sizeof(c_type) + total_size = length * itemsize + + # cupy array from OpenACC device pointer + current_device = cp.cuda.Device() + ptr_val = int(ffi.cast("uintptr_t", ptr)) + mem = cp.cuda.UnownedMemory(ptr_val, total_size, owner=ptr, device_id=current_device.id) + memptr = cp.cuda.MemoryPointer(mem, 0) + arr = cp.ndarray(shape=sizes, dtype=dtype, memptr=memptr, order="F") + return arr + + +def int_array_to_bool_array(int_array: np.typing.NDArray) -> np.typing.NDArray: + """ + Converts a NumPy array of integers to a boolean array. + In the input array, 0 represents False, and any non-zero value (1 or -1) represents True. + + Args: + int_array: A NumPy array of integers. + + Returns: + A NumPy array of booleans. + """ + bool_array = int_array != 0 + return bool_array diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/debug_utils.py b/tools/src/icon4py/tools/py2fgen/wrappers/debug_utils.py index a29c7390e..ccebd1175 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/debug_utils.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/debug_utils.py @@ -34,7 +34,7 @@ def print_grid_decomp_info( num_cells: int, num_edges: int, num_verts: int, -): +) -> None: log.info("icon_grid:cell_start%s", icon_grid._start_indices[CellDim]) log.info("icon_grid:cell_end:%s", icon_grid._end_indices[CellDim]) log.info("icon_grid:vert_start:%s", icon_grid._start_indices[VertexDim]) diff --git a/tools/tests/py2fgen/test_cffi.py b/tools/tests/py2fgen/test_cffi.py index 682c292a8..4da1c9e45 100644 --- a/tools/tests/py2fgen/test_cffi.py +++ b/tools/tests/py2fgen/test_cffi.py @@ -14,7 +14,8 @@ import pytest from cffi import FFI -from icon4py.tools.py2fgen.plugin import generate_and_compile_cffi_plugin, unpack +from icon4py.tools.py2fgen.plugin import generate_and_compile_cffi_plugin +from icon4py.tools.py2fgen.wrapper_utils import unpack @pytest.fixture @@ -34,7 +35,7 @@ def test_unpack_column_major(data, expected_result, ffi): rows, cols = expected_result.shape - result = unpack(ptr, rows, cols) + result = unpack(ffi, ptr, rows, cols) assert np.array_equal(result, expected_result) diff --git a/tools/tests/py2fgen/test_cli.py b/tools/tests/py2fgen/test_cli.py index 5615acd09..ddf0ca0c2 100644 --- a/tools/tests/py2fgen/test_cli.py +++ b/tools/tests/py2fgen/test_cli.py @@ -116,7 +116,7 @@ def compile_and_run_fortran( env.update(env_vars) fortran_result = run_fortran_executable(plugin_name, env) if expected_error_code == 0: - assert "passed" in fortran_result.stdout + assert "passed" in fortran_result.stdout, fortran_result.stderr else: assert "failed" in fortran_result.stdout except subprocess.CalledProcessError as e: @@ -127,7 +127,6 @@ def compile_and_run_fortran( "run_backend, extra_flags", [ ("CPU", ("-DUSE_SQUARE_FROM_FUNCTION",)), - ("CPU", ""), ], ) def test_py2fgen_compilation_and_execution_square_cpu( @@ -139,7 +138,7 @@ def test_py2fgen_compilation_and_execution_square_cpu( run_test_case( cli_runner, square_wrapper_module, - "square,square_from_function", + "square_from_function", "square_plugin", run_backend, samples_path, @@ -171,7 +170,13 @@ def test_py2fgen_python_error_propagation_to_fortran( @pytest.mark.parametrize( "function_name, plugin_name, test_name, run_backend, extra_flags", [ - ("square", "square_plugin", "test_square", "GPU", ("-acc", "-Minfo=acc")), + ( + "square_from_function", + "square_plugin", + "test_square", + "GPU", + ("-acc", "-Minfo=acc", "-DUSE_SQUARE_FROM_FUNCTION"), + ), ], ) def test_py2fgen_compilation_and_execution_gpu( diff --git a/tools/tests/py2fgen/test_codegen.py b/tools/tests/py2fgen/test_codegen.py index 735743c9a..e6a64c0b6 100644 --- a/tools/tests/py2fgen/test_codegen.py +++ b/tools/tests/py2fgen/test_codegen.py @@ -30,18 +30,14 @@ name="name", d_type=ScalarKind.FLOAT32, dimensions=[dims.CellDim, dims.KDim], - py_type_hint="Field[dims.CellDim, dims.KDim], float64]", ) field_1d = FuncParameter( name="name", d_type=ScalarKind.FLOAT32, dimensions=[dims.KDim], - py_type_hint="Field[dims.KDim], float64]", ) -simple_type = FuncParameter( - name="name", d_type=ScalarKind.FLOAT32, dimensions=[], py_type_hint="gtx.int32" -) +simple_type = FuncParameter(name="name", d_type=ScalarKind.FLOAT32, dimensions=[]) @pytest.mark.parametrize( @@ -54,15 +50,13 @@ def test_as_target(param, expected): foo = Func( name="foo", args=[ - FuncParameter(name="one", d_type=ScalarKind.INT32, dimensions=[], py_type_hint="gtx.int32"), + FuncParameter(name="one", d_type=ScalarKind.INT32, dimensions=[]), FuncParameter( name="two", d_type=ScalarKind.FLOAT64, dimensions=[dims.CellDim, dims.KDim], - py_type_hint="Field[dims.CellDim, dims.KDim], float64]", ), ], - is_gt4py_program=False, ) bar = Func( @@ -75,27 +69,21 @@ def test_as_target(param, expected): dims.CellDim, dims.KDim, ], - py_type_hint="Field[dims.CellDim, dims.KDim], float64]", ), - FuncParameter(name="two", d_type=ScalarKind.INT32, dimensions=[], py_type_hint="gtx.int32"), + FuncParameter(name="two", d_type=ScalarKind.INT32, dimensions=[]), ], - is_gt4py_program=False, ) def test_cheader_generation_for_single_function(): - plugin = CffiPlugin( - module_name="libtest", plugin_name="libtest_plugin", functions=[foo], imports=["import foo"] - ) + plugin = CffiPlugin(module_name="libtest", plugin_name="libtest_plugin", functions=[foo]) header = CHeaderGenerator.apply(plugin) assert header == "extern int foo_wrapper(int one, double* two, int n_Cell, int n_K);" def test_cheader_for_pointer_args(): - plugin = CffiPlugin( - module_name="libtest", plugin_name="libtest_plugin", functions=[bar], imports=["import bar"] - ) + plugin = CffiPlugin(module_name="libtest", plugin_name="libtest_plugin", functions=[bar]) header = CHeaderGenerator.apply(plugin) assert header == "extern int bar_wrapper(float* one, int two, int n_Cell, int n_K);" @@ -112,7 +100,6 @@ def dummy_plugin(): module_name="libtest", plugin_name="libtest_plugin", functions=[foo, bar], - imports=["import foo_module_x\nimport bar_module_y"], ) @@ -240,102 +227,37 @@ def test_python_wrapper(dummy_plugin): interface = generate_python_wrapper( dummy_plugin, "GPU", False, limited_area=True, profile=False ) - expected = ''' + expected = """ # imports for generated wrapper code import logging -import math + from libtest_plugin import ffi -import numpy as np import cupy as cp -from numpy.typing import NDArray from gt4py.next.iterator.embedded import np_as_located_field from icon4py.tools.py2fgen.settings import config -xp = config.array_ns +from icon4py.tools.py2fgen import wrapper_utils from icon4py.model.common import dimension as dims # logger setup -log_format = '%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s' -logging.basicConfig(level=logging.ERROR, - format=log_format, - datefmt='%Y-%m-%d %H:%M:%S') +log_format = "%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.ERROR, format=log_format, datefmt="%Y-%m-%d %H:%M:%S") logging.info(cp.show_config()) -import numpy as np - -# embedded module imports -import foo_module_x -import bar_module_y - # embedded function imports from libtest import foo from libtest import bar -def unpack_gpu(ptr, *sizes: int): - """ - Converts a C pointer into a CuPy array to directly manipulate memory allocated in Fortran. - This function is needed for operations that require in-place modification of GPU data, - enabling changes made in Python to reflect immediately in the original Fortran memory space. - - Args: - ptr (cffi.CData): A CFFI pointer to GPU memory allocated by OpenACC, representing - the starting address of the data. This pointer must correspond to - a contiguous block of memory whose total size matches the product - of the specified dimensions. - *sizes (int): Variable length argument list specifying the dimensions of the array. - These sizes determine the shape of the resulting CuPy array. - - Returns: - cp.ndarray: A CuPy array that provides a direct view of the data pointed to by `ptr`. - This array shares the underlying data with the original Fortran code, allowing - modifications made through the array to affect the original data. - """ - - if not sizes: - raise ValueError("Sizes must be provided to determine the array shape.") - - length = math.prod(sizes) - c_type = ffi.getctype(ffi.typeof(ptr).item) - - dtype_map = { - "int": cp.int32, - "double": cp.float64, - } - dtype = dtype_map.get(c_type, None) - if dtype is None: - raise ValueError(f"Unsupported C data type: {c_type}") - - itemsize = ffi.sizeof(c_type) - total_size = length * itemsize - - # cupy array from OpenACC device pointer - current_device = cp.cuda.Device() - ptr_val = int(ffi.cast("uintptr_t", ptr)) - mem = cp.cuda.UnownedMemory(ptr_val, total_size, owner=ptr, device_id=current_device.id) - memptr = cp.cuda.MemoryPointer(mem, 0) - arr = cp.ndarray(shape=sizes, dtype=dtype, memptr=memptr, order="F") - return arr - -def int_array_to_bool_array(int_array: NDArray) -> NDArray: - """ - Converts a NumPy array of integers to a boolean array. - In the input array, 0 represents False, and any non-zero value (1 or -1) represents True. - - Args: - int_array: A NumPy array of integers. - - Returns: - A NumPy array of booleans. - """ - bool_array = int_array != 0 - return bool_array @ffi.def_extern() -def foo_wrapper(one: gtx.int32, two: Field[dims.CellDim, dims.KDim], float64], n_Cell: gtx.int32, n_K: gtx.int32): +def foo_wrapper(one, two, n_Cell, n_K): try: + # Unpack pointers into Ndarrays - two = unpack_gpu(two, n_Cell, n_K) + + two = wrapper_utils.unpack_gpu(ffi, two, n_Cell, n_K) # Allocate GT4Py Fields + two = np_as_located_field(dims.CellDim, dims.KDim)(two) foo(one, two) @@ -346,13 +268,17 @@ def foo_wrapper(one: gtx.int32, two: Field[dims.CellDim, dims.KDim], float64], n return 0 + @ffi.def_extern() -def bar_wrapper(one: Field[dims.CellDim, dims.KDim], float64], two: gtx.int32, n_Cell: gtx.int32, n_K: gtx.int32): +def bar_wrapper(one, two, n_Cell, n_K): try: + # Unpack pointers into Ndarrays - one = unpack_gpu(one, n_Cell, n_K) + + one = wrapper_utils.unpack_gpu(ffi, one, n_Cell, n_K) # Allocate GT4Py Fields + one = np_as_located_field(dims.CellDim, dims.KDim)(one) bar(one, two) @@ -362,7 +288,7 @@ def bar_wrapper(one: Field[dims.CellDim, dims.KDim], float64], two: gtx.int32, n return 1 return 0 - ''' + """ assert compare_ignore_whitespace(interface, expected) diff --git a/tools/tests/py2fgen/test_parsing.py b/tools/tests/py2fgen/test_parsing.py index 481099d52..322e0ca57 100644 --- a/tools/tests/py2fgen/test_parsing.py +++ b/tools/tests/py2fgen/test_parsing.py @@ -6,11 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import ast - -import pytest - -from icon4py.tools.py2fgen.parsing import ImportStmtVisitor, TypeHintVisitor, parse +from icon4py.tools.py2fgen.parsing import parse from icon4py.tools.py2fgen.template import CffiPlugin @@ -28,26 +24,3 @@ def test_parse_functions_on_wrapper(): functions = ["diffusion_init", "diffusion_run"] plugin = parse(module_path, functions, "diffusion_plugin") assert isinstance(plugin, CffiPlugin) - - -def test_import_visitor(): - tree = ast.parse(source) - extractor = ImportStmtVisitor() - extractor.visit(tree) - expected_imports = ["import foo", "import bar"] - assert extractor.import_statements == expected_imports - - -def test_type_hint_visitor(): - tree = ast.parse(source) - visitor = TypeHintVisitor() - visitor.visit(tree) - expected_type_hints = {"x": "gtx.Field[gtx.Dims[EdgeDim, KDim], float64]", "y": "int"} - assert visitor.type_hints == expected_type_hints - - -def test_function_missing_type_hints(): - tree = ast.parse(source.replace(": int", "")) - visitor = TypeHintVisitor() - with pytest.raises(TypeError): - visitor.visit(tree)