Skip to content

Commit 84666fa

Browse files
FIX: fixing spmd tests utilities for dpctl inputs (#1999) (#2020)
* FIX: fixing spmd tests utilities for dpctl inputs (cherry picked from commit 18d0428) Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>
1 parent ae3c3a5 commit 84666fa

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

sklearnex/tests/_utils_spmd.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,11 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
146146
Raises:
147147
AssertionError: If results do not match.
148148
"""
149+
np_spmd_result = _as_numpy(spmd_result)
149150

150-
sorted_spmd_result = spmd_result[np.argsort(np.linalg.norm(spmd_result, axis=1))]
151+
sorted_spmd_result = np_spmd_result[
152+
np.argsort(np.linalg.norm(np_spmd_result, axis=1))
153+
]
151154
if localize:
152155
local_batch_result = _get_local_tensor(batch_result)
153156
sorted_batch_result = local_batch_result[
@@ -158,7 +161,7 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
158161
np.argsort(np.linalg.norm(batch_result, axis=1))
159162
]
160163

161-
assert_allclose(_as_numpy(sorted_spmd_result), sorted_batch_result, **kwargs)
164+
assert_allclose(sorted_spmd_result, sorted_batch_result, **kwargs)
162165

163166

164167
def _assert_kmeans_labels_allclose(
@@ -179,7 +182,11 @@ def _assert_kmeans_labels_allclose(
179182
AssertionError: If clusters are not correctly assigned.
180183
"""
181184

185+
np_spmd_labels = _as_numpy(spmd_labels)
186+
np_spmd_centers = _as_numpy(spmd_centers)
182187
local_batch_labels = _get_local_tensor(batch_labels)
183188
assert_allclose(
184-
spmd_centers[_as_numpy(spmd_labels)], batch_centers[local_batch_labels], **kwargs
189+
np_spmd_centers[np_spmd_labels],
190+
batch_centers[local_batch_labels],
191+
**kwargs,
185192
)

0 commit comments

Comments
 (0)