Skip to content

Commit

Permalink
Remove unused functionality from py2fgen (#656)
Browse files Browse the repository at this point in the history
- Remove special handling for wrapping gtx.Program (if needed at some
point, the cleaner way would be to first wrap the program and then use
the standard py2fgen)
- Remove type annotations from the wrapper (they were wrong and required
extra logic to extract them)
- Remove copying over the imports from the wrapped module to the
wrapper; it is unclear why that was needed in the past
- Enable mypy on `icon4py.tools` and add type annotations
- Move `unpack` and `int_array_to_bool_array` to the `wrapper_util` module
and import them from there (they are now required at runtime).
  • Loading branch information
havogt authored Feb 5, 2025
1 parent 9bab5b1 commit 3ffe859
Show file tree
Hide file tree
Showing 20 changed files with 246 additions and 453 deletions.
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ repos:
- id: tach
name: Check inter-package dependencies

# TODO(egparedes): fix and activate mypy hook
# - repo: local
# hooks:
# - id: mypy
# name: mypy static type checker
# entry: bash -c 'mypy tools/src/icon4py/tools model/common/src/icon4py/model/common'
# language: system
# types_or: [python, pyi]
# pass_filenames: false
# require_serial: true
# stages: [pre-commit]

- repo: local
hooks:
- id: mypy
name: mypy static type checker
entry: bash -c 'mypy tools/src/icon4py/tools' # model/common/src/icon4py/model/common # TODO(egparedes): fix and activate mypy hook for all packages
language: system
types_or: [python, pyi]
pass_filenames: false
require_serial: true
stages: [pre-commit]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ disallow_untyped_defs = true
# '^tests/liskov/*.py',
# '^tests/py2f/*.py',
# ]
ignore_missing_imports = true
ignore_missing_imports = false
implicit_reexport = true
install_types = true
non_interactive = true
Expand Down
4 changes: 2 additions & 2 deletions tools/src/icon4py/tools/common/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class DummyConnectivity(Connectivity):
neighbor_axis: Dimension = Dimension("unused")
index_type: type[int] = int

def mapped_index(self, cur_index, neigh_index) -> int:
def mapped_index(self, cur_index, neigh_index) -> int: # type: ignore[no-untyped-def] # code will disappear with next gt4py version
raise AssertionError("Unreachable")
return 0

Expand Down Expand Up @@ -89,7 +89,7 @@ def _get_field_infos(fvprog: Program) -> dict[str, FieldInfo]:
assert all(
_is_list_of_names(body.args) for body in fvprog.past_stage.past_node.body
), "Found unsupported expression in input arguments."
input_arg_ids = set(arg.id for body in fvprog.past_stage.past_node.body for arg in body.args) # type: ignore[attr-defined] # Checked in the assert
input_arg_ids = set(arg.id for body in fvprog.past_stage.past_node.body for arg in body.args)

out_args = (body.kwargs["out"] for body in fvprog.past_stage.past_node.body)
output_fields = []
Expand Down
4 changes: 3 additions & 1 deletion tools/src/icon4py/tools/icon4pygen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class ModuleType(click.ParamType):
f"{dycore_import_path}.mo_velocity_advection_stencil_",
]

def shell_complete(self, ctx, param, incomplete):
def shell_complete(
self, ctx: click.Context, param: click.Parameter, incomplete: str
) -> list[click.shell_completion.CompletionItem]:
if len(incomplete) > 0 and incomplete.endswith(":"):
completions = [incomplete + incomplete[:-1].split(".")[-1]]
else:
Expand Down
4 changes: 3 additions & 1 deletion tools/src/icon4py/tools/liskov/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
logger = setup_logger(__name__)


def split_comma(ctx, param, value) -> Optional[tuple[str]]:
def split_comma(
ctx: click.Context, param: click.Parameter, value: str
) -> Optional[tuple[str, ...]]:
return tuple(v.strip() for v in value.split(",")) if value else None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class EndStencilStatement(EndBasicStencilStatement):
noprofile: Optional[bool]
noaccenddata: Optional[bool]

def __post_init__(self, *args, **kwargs) -> None:
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields]
self.bounds_fields = BoundsFields(**asdict(self.stencil_data.bounds))
self.name = self.stencil_data.name
Expand Down
2 changes: 1 addition & 1 deletion tools/src/icon4py/tools/liskov/parsing/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _validate_stencil_directives(self, directives: Sequence[ts.ParsedDirective])
def _identify_unbalanced_directives(
directives: Sequence[ts.ParsedDirective],
directive_types: tuple[Type[ts.ParsedDirective], ...],
):
) -> None:
directive_counts: dict[str, int] = defaultdict(int)
for directive in directives:
if isinstance(directive, directive_types):
Expand Down
2 changes: 1 addition & 1 deletion tools/src/icon4py/tools/py2fgen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from icon4py.tools.py2fgen.settings import GT4PyBackend


def parse_comma_separated_list(ctx, param, value) -> list[str]:
def parse_comma_separated_list(ctx: click.Context, param: click.Parameter, value: str) -> list[str]:
# Splits the input string by commas and strips any leading/trailing whitespace from the strings
return [item.strip() for item in value.split(",")]

Expand Down
1 change: 0 additions & 1 deletion tools/src/icon4py/tools/py2fgen/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def generate_python_wrapper(
module_name=plugin.module_name,
plugin_name=plugin.plugin_name,
functions=plugin.functions,
imports=plugin.imports,
backend=backend,
debug_mode=debug_mode,
limited_area=limited_area,
Expand Down
131 changes: 12 additions & 119 deletions tools/src/icon4py/tools/py2fgen/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,147 +6,40 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import ast
import importlib
import inspect
import re
from inspect import signature, unwrap
from types import ModuleType
from typing import Callable, List

from gt4py.next import Dimension
from gt4py.next.ffront.decorator import Program
from gt4py.next.type_system.type_translation import from_type_hint
from gt4py.next.type_system import type_translation as gtx_type_translation

from icon4py.tools.py2fgen.template import CffiPlugin, Func, FuncParameter
from icon4py.tools.py2fgen.utils import parse_type_spec


class ImportStmtVisitor(ast.NodeVisitor):
"""AST Visitor to extract import statements."""

def __init__(self):
self.import_statements: list[str] = []

def visit_Import(self, node):
for alias in node.names:
import_statement = f"import {alias.name}" + (
f" as {alias.asname}" if alias.asname else ""
)
self.import_statements.append(import_statement)

def visit_ImportFrom(self, node):
for alias in node.names:
import_statement = f"from {node.module} import {alias.name}" + (
f" as {alias.asname}" if alias.asname else ""
)
self.import_statements.append(import_statement)


class TypeHintVisitor(ast.NodeVisitor):
"""AST Visitor to extract function parameter type hints."""

def __init__(self):
self.type_hints: dict[str, str] = {}

def visit_FunctionDef(self, node):
for arg in node.args.args:
if arg.annotation:
annotation = ast.unparse(arg.annotation)
self.type_hints[arg.arg] = annotation
else:
raise TypeError(
f"Missing type hint for parameter '{arg.arg}' in function '{node.name}'"
)


def parse(module_name: str, functions: list[str], plugin_name: str) -> CffiPlugin:
module = importlib.import_module(module_name)
parsed_imports = _extract_import_statements(module)

parsed_functions: list[Func] = []
for f in functions:
parsed_functions.append(_parse_function(module, f))
parsed_functions = [_parse_function(module, f) for f in functions]

return CffiPlugin(
module_name=module_name,
plugin_name=plugin_name,
functions=parsed_functions,
imports=parsed_imports,
)


def _extract_import_statements(module: ModuleType) -> list[str]:
src = inspect.getsource(module)
tree = ast.parse(src)
visitor = ImportStmtVisitor()
visitor.visit(tree)
return visitor.import_statements


def _parse_function(module: ModuleType, function_name: str) -> Func:
func = unwrap(getattr(module, function_name))
is_gt4py_program = isinstance(func, Program)
type_hints = _extract_type_hint_strings(module, func, is_gt4py_program, function_name)
params = _parse_params(func)
return Func(name=function_name, args=params)

params = (
_get_gt4py_func_params(func, type_hints)
if is_gt4py_program
else _get_simple_func_params(func, type_hints)
)

return Func(name=function_name, args=params, is_gt4py_program=is_gt4py_program)


def _extract_type_hint_strings(
module: ModuleType, func: Callable, is_gt4py_program: bool, function_name: str
):
src = extract_function_signature(
inspect.getsource(module) if is_gt4py_program else inspect.getsource(func), function_name
)
tree = ast.parse(src)
visitor = TypeHintVisitor()
visitor.visit(tree)
return visitor.type_hints


def extract_function_signature(code: str, function_name: str) -> str:
# This pattern attempts to match function definitions
pattern = rf"\bdef\s+{re.escape(function_name)}\s*\(([\s\S]*?)\)\s*:"

match = re.search(pattern, code)

if match:
# Constructing the full signature with empty return for ease of parsing by AST visitor
signature = match.group()
return signature.strip() + "\n return None"
else:
raise Exception(f"Could not parse function signature from the following code:\n {code}")


def _get_gt4py_func_params(func: Program, type_hints: dict[str, str]) -> List[FuncParameter]:
return [
FuncParameter(
name=p.id,
d_type=parse_type_spec(p.type)[1],
dimensions=parse_type_spec(p.type)[0],
py_type_hint=type_hints[p.id],
)
for p in func.past_stage.past_node.params
]


def _get_simple_func_params(func: Callable, type_hints: dict[str, str]) -> List[FuncParameter]:
def _parse_params(func: Callable) -> List[FuncParameter]:
sig_params = signature(func, follow_wrapped=False).parameters
return [
FuncParameter(
name=s,
d_type=parse_type_spec(from_type_hint(param.annotation))[1],
dimensions=[
Dimension(value=d.value)
for d in parse_type_spec(from_type_hint(param.annotation))[0]
],
py_type_hint=type_hints.get(s, None),
)
for s, param in sig_params.items()
]
params = []
for s, param in sig_params.items():
gt4py_type = gtx_type_translation.from_type_hint(param.annotation)
dims, dtype = parse_type_spec(gt4py_type)
params.append(FuncParameter(name=s, d_type=dtype, dimensions=dims))

return params
106 changes: 0 additions & 106 deletions tools/src/icon4py/tools/py2fgen/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,122 +7,16 @@
# SPDX-License-Identifier: BSD-3-Clause

import logging
import math
import typing
from pathlib import Path

import cffi
import numpy as np
from cffi import FFI
from numpy.typing import NDArray

from icon4py.tools.common.logger import setup_logger


if typing.TYPE_CHECKING:
import cupy as cp # type: ignore

ffi = FFI() # needed for unpack and unpack_gpu functions

logger = setup_logger(__name__)


def unpack(ptr, *sizes: int) -> NDArray:
"""
Converts a C pointer into a NumPy array to directly manipulate memory allocated in Fortran.
This function is needed for operations requiring in-place modification of CPU data, enabling
changes made in Python to reflect immediately in the original Fortran memory space.
Args:
ptr (CData): A CFFI pointer to the beginning of the data array in CPU memory. This pointer
should reference a contiguous block of memory whose total size matches the product
of the specified dimensions.
*sizes (int): Variable length argument list specifying the dimensions of the array.
These sizes determine the shape of the resulting NumPy array.
Returns:
np.ndarray: A NumPy array that provides a direct view of the data pointed to by `ptr`.
This array shares the underlying data with the original Fortran code, allowing
modifications made through the array to affect the original data.
"""
length = math.prod(sizes)
c_type = ffi.getctype(ffi.typeof(ptr).item)

# Map C data types to NumPy dtypes
dtype_map: dict[str, np.dtype] = {
"int": np.dtype(np.int32),
"double": np.dtype(np.float64),
}
dtype = dtype_map.get(c_type, np.dtype(c_type))

# Create a NumPy array from the buffer, specifying the Fortran order
arr = np.frombuffer(ffi.buffer(ptr, length * ffi.sizeof(c_type)), dtype=dtype).reshape( # type: ignore
sizes, order="F"
)
return arr


def unpack_gpu(ptr, *sizes: int):
"""
Converts a C pointer into a CuPy array to directly manipulate memory allocated in Fortran.
This function is needed for operations that require in-place modification of GPU data,
enabling changes made in Python to reflect immediately in the original Fortran memory space.
Args:
ptr (cffi.CData): A CFFI pointer to GPU memory allocated by OpenACC, representing
the starting address of the data. This pointer must correspond to
a contiguous block of memory whose total size matches the product
of the specified dimensions.
*sizes (int): Variable length argument list specifying the dimensions of the array.
These sizes determine the shape of the resulting CuPy array.
Returns:
cp.ndarray: A CuPy array that provides a direct view of the data pointed to by `ptr`.
This array shares the underlying data with the original Fortran code, allowing
modifications made through the array to affect the original data.
"""

if not sizes:
raise ValueError("Sizes must be provided to determine the array shape.")

length = math.prod(sizes)
c_type = ffi.getctype(ffi.typeof(ptr).item)

dtype_map = {
"int": cp.int32,
"double": cp.float64,
}
dtype = dtype_map.get(c_type, None)
if dtype is None:
raise ValueError(f"Unsupported C data type: {c_type}")

itemsize = ffi.sizeof(c_type)
total_size = length * itemsize

# cupy array from OpenACC device pointer
current_device = cp.cuda.Device()
ptr_val = int(ffi.cast("uintptr_t", ptr))
mem = cp.cuda.UnownedMemory(ptr_val, total_size, owner=ptr, device_id=current_device.id)
memptr = cp.cuda.MemoryPointer(mem, 0)
arr = cp.ndarray(shape=sizes, dtype=dtype, memptr=memptr, order="F")
return arr


def int_array_to_bool_array(int_array: NDArray) -> NDArray:
"""
Converts a NumPy array of integers to a boolean array.
In the input array, 0 represents False, and any non-zero value (1 or -1) represents True.
Args:
int_array: A NumPy array of integers.
Returns:
A NumPy array of booleans.
"""
bool_array = int_array != 0
return bool_array


def generate_and_compile_cffi_plugin(
plugin_name: str, c_header: str, python_wrapper: str, build_path: Path, backend: str
) -> None:
Expand Down
Loading

0 comments on commit 3ffe859

Please sign in to comment.