Skip to content

Commit

Permalink
fix wrapped array pickling bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Feb 9, 2025
1 parent 5bd8e66 commit e09c1b4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
3 changes: 3 additions & 0 deletions python/cudf/cudf/pandas/_wrappers/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..fast_slow_proxy import (
_fast_slow_function_call,
_FastSlowAttribute,
_FinalProxy,
is_proxy_object,
make_final_proxy_type,
make_intermediate_proxy_type,
Expand Down Expand Up @@ -162,6 +163,8 @@ def ndarray__array_ufunc__(self, ufunc, method, *inputs, **kwargs):
slow_to_fast=lambda slow: cupy.asarray(slow).flat,
additional_attributes={
"__array__": array_method,
"__reduce__": _FinalProxy.__reduce__,
"__setstate__": _FinalProxy.__setstate__,
},
)

Expand Down
23 changes: 7 additions & 16 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from enum import IntEnum
from typing import Any, Literal

import cupy as cp
import numpy as np

from rmm import RMMError
Expand Down Expand Up @@ -191,10 +190,7 @@ def _fsproxy_slow_to_fast(self):
# convert it to a fast one
if self._fsproxy_state is _State.SLOW:
return slow_to_fast(self._fsproxy_wrapped)
try:
return self._fsproxy_wrapped
except AttributeError:
return cp.asarray(self)
return self._fsproxy_wrapped

@nvtx.annotate(
"COPY_FAST_TO_SLOW",
Expand All @@ -206,10 +202,7 @@ def _fsproxy_fast_to_slow(self):
# convert it to a slow one
if self._fsproxy_state is _State.FAST:
return fast_to_slow(self._fsproxy_wrapped)
try:
return self._fsproxy_wrapped
except AttributeError:
return np.asarray(self)
return self._fsproxy_wrapped

def as_gpu_object(self):
return self._fsproxy_slow_to_fast()
Expand All @@ -219,13 +212,11 @@ def as_cpu_object(self):

@property # type: ignore
def _fsproxy_state(self) -> _State:
try:
is_fast_type = isinstance(
self._fsproxy_wrapped, self._fsproxy_fast_type
)
except AttributeError:
is_fast_type = isinstance(self, self._fsproxy_fast_type)
return _State.FAST if is_fast_type else _State.SLOW
return (
_State.FAST
if isinstance(self._fsproxy_wrapped, self._fsproxy_fast_type)
else _State.SLOW
)

slow_dir = dir(slow_type)
cls_dict = {
Expand Down

0 comments on commit e09c1b4

Please sign in to comment.