Skip to content

Commit c989150

Browse files
committed
Add tests
1 parent 39f057a commit c989150

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@
2828

2929
from .helper import get_queue_or_skip, skip_if_dtype_not_supported
3030

31+
_all_dtypes = [
32+
"b1",
33+
"i1",
34+
"u1",
35+
"i2",
36+
"u2",
37+
"i4",
38+
"u4",
39+
"i8",
40+
"u8",
41+
"f2",
42+
"f4",
43+
"f8",
44+
"c8",
45+
"c16",
46+
]
47+
3148

3249
@pytest.mark.parametrize(
3350
"shape",
@@ -150,6 +167,21 @@ def test_usm_ndarray_writable_flag_views():
150167
assert not a.imag.flags.writable
151168

152169

170+
@pytest.mark.parametrize("dt1", _all_dtypes)
171+
@pytest.mark.parametrize("dt2", _all_dtypes)
172+
def test_usm_ndarray_from_zero_sized_usm_ndarray(dt1, dt2):
173+
q = get_queue_or_skip()
174+
skip_if_dtype_not_supported(dt1, q)
175+
skip_if_dtype_not_supported(dt2, q)
176+
177+
x1 = dpt.ones((0,), dtype=dt1, sycl_queue=q)
178+
x2 = dpt.usm_ndarray(x1.shape, dtype=dt2, buffer=x1)
179+
assert x2.dtype == dt2
180+
assert x2.sycl_queue == q
181+
assert x2._pointer == x1._pointer
182+
assert x2.shape == x1.shape
183+
184+
153185
def test_usm_ndarray_from_usm_ndarray_readonly():
154186
get_queue_or_skip()
155187

@@ -161,20 +193,8 @@ def test_usm_ndarray_from_usm_ndarray_readonly():
161193

162194
@pytest.mark.parametrize(
163195
"dtype",
164-
[
165-
"u1",
166-
"i1",
167-
"u2",
168-
"i2",
169-
"u4",
170-
"i4",
171-
"u8",
172-
"i8",
173-
"f2",
174-
"f4",
175-
"f8",
176-
"c8",
177-
"c16",
196+
_all_dtypes
197+
+ [
178198
b"float32",
179199
dpt.dtype("d"),
180200
np.half,
@@ -1103,24 +1123,6 @@ def test_pyx_capi_check_constants():
11031123
assert cdouble_typenum == dpt.dtype(np.cdouble).num
11041124

11051125

1106-
_all_dtypes = [
1107-
"b1",
1108-
"i1",
1109-
"u1",
1110-
"i2",
1111-
"u2",
1112-
"i4",
1113-
"u4",
1114-
"i8",
1115-
"u8",
1116-
"f2",
1117-
"f4",
1118-
"f8",
1119-
"c8",
1120-
"c16",
1121-
]
1122-
1123-
11241126
@pytest.mark.parametrize(
11251127
"shape", [tuple(), (1,), (5,), (2, 3), (2, 3, 4), (2, 2, 2, 2, 2)]
11261128
)

0 commit comments

Comments
 (0)