|
17 | 17 | import inspect
|
18 | 18 | import logging
|
19 | 19 | from functools import wraps
|
| 20 | +from operator import xor |
20 | 21 |
|
21 | 22 | import numpy as np
|
22 | 23 | from sklearn import get_config
|
23 | 24 |
|
24 | 25 | from ._config import _get_config
|
25 |
| -from .datatypes import copy_to_dpnp, copy_to_usm, dlpack_to_numpy, usm_to_numpy |
| 26 | +from .datatypes import copy_to_dpnp, copy_to_usm, dlpack_to_numpy |
26 | 27 | from .utils import _sycl_queue_manager as QM
|
27 |
| -from .utils._array_api import _asarray, _is_numpy_namespace |
| 28 | +from .utils._array_api import _asarray, _get_sycl_namespace, _is_numpy_namespace |
28 | 29 | from .utils._third_party import is_dpnp_ndarray
|
29 | 30 |
|
30 | 31 | logger = logging.getLogger("sklearnex")
|
@@ -62,23 +63,23 @@ def wrapper(self, *args, **kwargs):
|
62 | 63 |
|
63 | 64 |
|
64 | 65 | def _transfer_to_host(*data):
|
65 |
| - has_usm_data, has_host_data = False, False |
| 66 | + has_usm_data = None |
66 | 67 |
|
67 | 68 | host_data = []
|
68 | 69 | for item in data:
|
69 |
| - if usm_iface := getattr(item, "__sycl_usm_array_interface__", None): |
70 |
| - item = usm_to_numpy(item, usm_iface) |
71 |
| - has_usm_data = True |
| 70 | + if item is None: |
| 71 | + host_data.append(item) |
| 72 | + continue |
| 73 | + |
| 74 | + if usm_iface := hasattr(item, "__sycl_usm_array_interface__"): |
| 75 | + xp = item.__array_namespace__() |
| 76 | + item = xp.asnumpy(item) |
| 77 | + has_usm_data = has_usm_data or has_usm_data is None |
72 | 78 | elif not isinstance(item, np.ndarray) and (hasattr(item, "__dlpack_device__")):
|
73 | 79 | item = dlpack_to_numpy(item)
|
74 |
| - has_host_data = True |
75 |
| - else: |
76 |
| - has_host_data = True |
77 |
| - |
78 |
| - mismatch_host_item = usm_iface is None and item is not None and has_usm_data |
79 |
| - mismatch_usm_item = usm_iface is not None and has_host_data |
80 | 80 |
|
81 |
| - if mismatch_host_item or mismatch_usm_item: |
| 81 | + # set has_usm_data to boolean and use xor to see if they don't match |
| 82 | + if xor((has_usm_data := bool(has_usm_data)), usm_iface): |
82 | 83 | raise RuntimeError("Input data shall be located on single target device")
|
83 | 84 |
|
84 | 85 | host_data.append(item)
|
@@ -171,3 +172,27 @@ def wrapper_impl(*args, **kwargs):
|
171 | 172 | return result
|
172 | 173 |
|
173 | 174 | return wrapper_impl
|
| 175 | + |
| 176 | + |
| 177 | +def support_sycl_format(func): |
| 178 | + # This wrapper enables scikit-learn functions and methods to work with |
| 179 | + # all sycl data frameworks as they no longer support numpy implicit |
| 180 | + # conversion and must be manually converted. This is only necessary |
| 181 | + # when array API is supported but not active. |
| 182 | + |
| 183 | + @wraps(func) |
| 184 | + def wrapper(*args, **kwargs): |
| 185 | + if ( |
| 186 | + not get_config().get("array_api_dispatch", False) |
| 187 | + and _get_sycl_namespace(*args)[2] |
| 188 | + ): |
| 189 | + with QM.manage_global_queue(kwargs.get("queue"), *args): |
| 190 | + if inspect.isfunction(func) and "." in func.__qualname__: |
| 191 | + self, (args, kwargs) = args[0], _get_host_inputs(*args[1:], **kwargs) |
| 192 | + return func(self, *args, **kwargs) |
| 193 | + else: |
| 194 | + args, kwargs = _get_host_inputs(*args, **kwargs) |
| 195 | + return func(*args, **kwargs) |
| 196 | + return func(*args, **kwargs) |
| 197 | + |
| 198 | + return wrapper |
0 commit comments