FIX: update functional support fallback logic for DPNP/DPCTL ndarray inputs (#2113)

* FIX: update the functional support fallback logic

* Host NumPy copies of the input data are used for the fallback cases, since stock scikit-learn doesn't support DPCTL usm_ndarray and DPNP ndarray

* Added a clarifying comment

* Enhanced patch message for data transfer
samir-nasibli authored Oct 22, 2024
1 parent 0809a3e commit c0eb5ad
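For context, the fallback idea can be shown with a minimal, self-contained sketch (illustrative only; `host_copies` is not a name from this patch): inputs backed by SYCL USM memory are detected via the `__sycl_usm_array_interface__` protocol and replaced with host NumPy copies, since stock scikit-learn cannot consume DPCTL usm_ndarray or DPNP ndarray directly.

```python
try:
    import dpctl.tensor as dpt
except ImportError:  # dpctl is an optional dependency
    dpt = None
try:
    import dpnp
except ImportError:  # dpnp is an optional dependency
    dpnp = None


def host_copies(*data):
    """Return (has_usm_data, host_data), copying USM-backed inputs to host."""
    has_usm_data = False
    host_data = []
    for item in data:
        # Both dpctl.tensor.usm_ndarray and dpnp.ndarray expose this protocol.
        if getattr(item, "__sycl_usm_array_interface__", None) is not None:
            has_usm_data = True
            if dpnp is not None and isinstance(item, dpnp.ndarray):
                item = dpnp.asnumpy(item)  # explicit device-to-host copy
            elif dpt is not None:
                item = dpt.asnumpy(item)  # explicit device-to-host copy
        host_data.append(item)
    return has_usm_data, host_data
```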
Showing 3 changed files with 27 additions and 12 deletions.
6 changes: 3 additions & 3 deletions onedal/_device_offload.py
@@ -133,7 +133,7 @@ def _transfer_to_host(queue, *data):
                 raise RuntimeError("Input data shall be located on single target device")
 
         host_data.append(item)
-    return queue, host_data
+    return has_usm_data, queue, host_data
 
 
 def _get_global_queue():
@@ -150,8 +150,8 @@ def _get_global_queue():

 def _get_host_inputs(*args, **kwargs):
     q = _get_global_queue()
-    q, hostargs = _transfer_to_host(q, *args)
-    q, hostvalues = _transfer_to_host(q, *kwargs.values())
+    _, q, hostargs = _transfer_to_host(q, *args)
+    _, q, hostvalues = _transfer_to_host(q, *kwargs.values())
     hostkwargs = dict(zip(kwargs.keys(), hostvalues))
     return q, hostargs, hostkwargs

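The changed contract above can be summarized with a small stand-in (a hypothetical stub, not the real implementation): `_transfer_to_host` now returns a leading `has_usm_data` flag before the queue and host data, and callers that don't need the flag discard it with `_`, as `_get_host_inputs` does.

```python
def _transfer_to_host_stub(queue, *data):
    # Flag first, then queue and host copies -- mirroring the new signature.
    has_usm_data = any(
        getattr(item, "__sycl_usm_array_interface__", None) is not None
        for item in data
    )
    host_data = list(data)  # the real code copies USM inputs to host here
    return has_usm_data, queue, host_data


# A caller that needs the flag (as `dispatch` in sklearnex does):
has_usm_data, q, host_args = _transfer_to_host_stub(None, [1.0], [2.0])

# A caller that does not (as `_get_host_inputs` above does):
_, q, host_args = _transfer_to_host_stub(None, [1.0], [2.0])
```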
19 changes: 14 additions & 5 deletions sklearnex/_device_offload.py
@@ -63,25 +63,34 @@ def _get_backend(obj, queue, method_name, *data):
 
 def dispatch(obj, method_name, branches, *args, **kwargs):
     q = _get_global_queue()
-    q, hostargs = _transfer_to_host(q, *args)
-    q, hostvalues = _transfer_to_host(q, *kwargs.values())
+    has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
+    has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
     hostkwargs = dict(zip(kwargs.keys(), hostvalues))
 
     backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)
 
+    has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
     if backend == "onedal":
-        patching_status.write_log(queue=q)
+        # Host args are only used before the onedal backend call.
+        # Data will be offloaded to the device when the onedal backend is called.
+        patching_status.write_log(queue=q, transferred_to_host=False)
         return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
     if backend == "sklearn":
         if (
             "array_api_dispatch" in get_config()
             and get_config()["array_api_dispatch"]
             and "array_api_support" in obj._get_tags()
             and obj._get_tags()["array_api_support"]
+            and not has_usm_data
         ):
+            # USM ndarrays are also excluded from the Array API fallback. Currently,
+            # DPNP.ndarray is not compliant with the Array API standard, and DPCTL
+            # usm_ndarray is compliant except for the linalg module, so there is no
+            # guarantee that stock scikit-learn will work with such input data. This
+            # condition will be updated once DPNP.ndarray and DPCTL usm_ndarray are
+            # enabled for conformance testing and support the fallback cases.
             # If `array_api_dispatch` is enabled and the array API is supported by
             # stock scikit-learn, the raw inputs are used for the fallback.
-            patching_status.write_log()
+            patching_status.write_log(transferred_to_host=False)
             return branches[backend](obj, *args, **kwargs)
         else:
             patching_status.write_log()
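Condensed, the fallback decision this hunk implements looks roughly like the sketch below (simplified control flow; `choose_sklearn_inputs` is an illustrative name, and `_get_tags()` is the dict-returning estimator tags API this code relies on). Raw device arrays reach stock scikit-learn only when Array API dispatch is enabled, the estimator declares Array API support, and no input is USM-backed; otherwise the host NumPy copies are used.

```python
from sklearn import get_config


def choose_sklearn_inputs(estimator, has_usm_data, raw_inputs, host_inputs):
    cfg = get_config()
    array_api_fallback_ok = (
        cfg.get("array_api_dispatch", False)
        and estimator._get_tags().get("array_api_support", False)
        and not has_usm_data
    )
    # Raw inputs avoid an extra device-to-host transfer; host copies are
    # the safe default that stock scikit-learn is guaranteed to accept.
    return raw_inputs if array_api_fallback_ok else host_inputs
```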
14 changes: 10 additions & 4 deletions sklearnex/_utils.py
@@ -29,10 +29,10 @@ class PatchingConditionsChain(daal4py_PatchingConditionsChain):
     def get_status(self):
         return self.patching_is_enabled
 
-    def write_log(self, queue=None):
+    def write_log(self, queue=None, transferred_to_host=True):
         if self.patching_is_enabled:
             self.logger.info(
-                f"{self.scope_name}: {get_patch_message('onedal', queue=queue)}"
+                f"{self.scope_name}: {get_patch_message('onedal', queue=queue, transferred_to_host=transferred_to_host)}"
             )
         else:
             self.logger.debug(
@@ -43,7 +43,9 @@ def write_log(self, queue=None):
             self.logger.debug(
                 f"{self.scope_name}: patching failed with cause - {message}"
             )
-        self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}")
+        self.logger.info(
+            f"{self.scope_name}: {get_patch_message('sklearn', transferred_to_host=transferred_to_host)}"
+        )
 
 
 def set_sklearn_ex_verbose():
@@ -66,7 +68,7 @@ def set_sklearn_ex_verbose():
     )
 
 
-def get_patch_message(s, queue=None):
+def get_patch_message(s, queue=None, transferred_to_host=True):
     if s == "onedal":
         message = "running accelerated version on "
         if queue is not None:
@@ -87,6 +89,10 @@ def get_patch_message(s, queue=None):
             f"Invalid input - expected one of 'onedal','sklearn',"
             f" 'sklearn_after_onedal', got {s}"
         )
+    if transferred_to_host:
+        message += (
+            ". All input data transferred to host for further backend computations."
+        )
     return message


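A short usage sketch of the new flag (assuming a sklearnex build that contains this commit): the host-transfer note is appended to the patch message by default and suppressed when the caller passes `transferred_to_host=False`.

```python
from sklearnex._utils import get_patch_message

NOTE = ". All input data transferred to host for further backend computations."

# Default: the fallback message carries the data-transfer note.
assert get_patch_message("sklearn").endswith(NOTE)

# The onedal offload and Array API fallback paths disable the note,
# since no host transfer happens there.
assert not get_patch_message("sklearn", transferred_to_host=False).endswith(NOTE)
```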
