Skip to content

Commit 871eabc

Browse files
Support jax 0.4.27 in CI tests (#6885)
Use jax Array devices method instead of device
1 parent 865e9b1 commit 871eabc

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/test_formatting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -627,13 +627,13 @@ def test_jax_formatter_device(self):
627627
device = jax.devices()[0]
628628
formatter = JaxFormatter(device=str(device))
629629
row = formatter.format_row(pa_table)
630-
assert row["a"].device() == device
631-
assert row["c"].device() == device
630+
assert row["a"].devices().pop() == device
631+
assert row["c"].devices().pop() == device
632632
col = formatter.format_column(pa_table)
633-
assert col.device() == device
633+
assert col.devices().pop() == device
634634
batch = formatter.format_batch(pa_table)
635-
assert batch["a"].device() == device
636-
assert batch["c"].device() == device
635+
assert batch["a"].devices().pop() == device
636+
assert batch["c"].devices().pop() == device
637637

638638

639639
class QueryTest(TestCase):

0 commit comments

Comments
 (0)