diff --git a/src/awkward/_kernels.py b/src/awkward/_kernels.py index 28c38cf1db..537aaa750f 100644 --- a/src/awkward/_kernels.py +++ b/src/awkward/_kernels.py @@ -12,6 +12,7 @@ from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.typetracer import try_touch_data +from awkward._nplikes.virtual import materialize_if_virtual from awkward._typing import Protocol, TypeAlias KernelKeyType: TypeAlias = tuple # Tuple[str, Unpack[Tuple[metadata.dtype, ...]]] @@ -88,6 +89,8 @@ def _cast(cls, x, t): def __call__(self, *args) -> None: assert len(args) == len(self._impl.argtypes) + args = materialize_if_virtual(*args) + return self._impl( *(self._cast(x, t) for x, t in zip(args, self._impl.argtypes)) ) @@ -97,6 +100,8 @@ class JaxKernel(NumpyKernel): def __call__(self, *args) -> None: assert len(args) == len(self._impl.argtypes) + args = materialize_if_virtual(*args) + if not any(Jax.is_tracer_type(type(arg)) for arg in args): return super().__call__(*args) @@ -138,6 +143,8 @@ def _cast(self, x, type_): def __call__(self, *args) -> None: import awkward._connect.cuda as ak_cuda + args = materialize_if_virtual(*args) + cupy = ak_cuda.import_cupy("Awkward Arrays with CUDA") maxlength = self.max_length(args) grid, blocks = self.calc_grid(maxlength), self.calc_blocks(maxlength) diff --git a/src/awkward/_nplikes/numpy.py b/src/awkward/_nplikes/numpy.py index 3043e2c2ed..7157e3a769 100644 --- a/src/awkward/_nplikes/numpy.py +++ b/src/awkward/_nplikes/numpy.py @@ -8,6 +8,7 @@ from awkward._nplikes.dispatch import register_nplike from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.placeholder import PlaceholderArray +from awkward._nplikes.virtual import VirtualArray from awkward._typing import TYPE_CHECKING, Final, Literal if TYPE_CHECKING: @@ -48,7 +49,8 @@ def is_own_array_type(cls, type_: type) -> bool: return issubclass(type_, numpy.ndarray) def is_c_contiguous(self, x: NDArray | PlaceholderArray) -> bool: - if isinstance(x, PlaceholderArray): + # TODO: What should this do for virtual arrays? + if isinstance(x, (PlaceholderArray, VirtualArray)): return True else: return x.flags["C_CONTIGUOUS"] # type: ignore[attr-defined]