@@ -146,8 +146,11 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
146
146
Raises:
147
147
AssertionError: If results do not match.
148
148
"""
149
+ np_spmd_result = _as_numpy (spmd_result )
149
150
150
- sorted_spmd_result = spmd_result [np .argsort (np .linalg .norm (spmd_result , axis = 1 ))]
151
+ sorted_spmd_result = np_spmd_result [
152
+ np .argsort (np .linalg .norm (np_spmd_result , axis = 1 ))
153
+ ]
151
154
if localize :
152
155
local_batch_result = _get_local_tensor (batch_result )
153
156
sorted_batch_result = local_batch_result [
@@ -158,7 +161,7 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
158
161
np .argsort (np .linalg .norm (batch_result , axis = 1 ))
159
162
]
160
163
161
- assert_allclose (_as_numpy ( sorted_spmd_result ) , sorted_batch_result , ** kwargs )
164
+ assert_allclose (sorted_spmd_result , sorted_batch_result , ** kwargs )
162
165
163
166
164
167
def _assert_kmeans_labels_allclose (
@@ -179,7 +182,11 @@ def _assert_kmeans_labels_allclose(
179
182
AssertionError: If clusters are not correctly assigned.
180
183
"""
181
184
185
+ np_spmd_labels = _as_numpy (spmd_labels )
186
+ np_spmd_centers = _as_numpy (spmd_centers )
182
187
local_batch_labels = _get_local_tensor (batch_labels )
183
188
assert_allclose (
184
- spmd_centers [_as_numpy (spmd_labels )], batch_centers [local_batch_labels ], ** kwargs
189
+ np_spmd_centers [np_spmd_labels ],
190
+ batch_centers [local_batch_labels ],
191
+ ** kwargs ,
185
192
)
0 commit comments