Skip to content

Commit e2c0469

Browse files
authored
Add implementation of dpnp.ndarray.data and dpnp.ndarray.data.ptr attributes (#2521)
This PR adds implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes. The tests are updated to use the new attributes where applicable.
1 parent 463af22 commit e2c0469

File tree

14 files changed

+243
-58
lines changed

14 files changed

+243
-58
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
* Added implementation of `dpnp.ndarray.view` method [#2520](https://github.com/IntelPython/dpnp/pull/2520)
1414
* Added a new backend routine `syrk` from oneMKL to perform symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix [2509](https://github.com/IntelPython/dpnp/pull/2509)
1515
* Added `timeout-minutes` property to GitHub jobs [#2526](https://github.com/IntelPython/dpnp/pull/2526)
16+
* Added implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes [#2521](https://github.com/IntelPython/dpnp/pull/2521)
1617

1718
### Changed
1819

dpnp/dpnp_array.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from dpctl.tensor._numpy_helper import AxisError
3030

3131
import dpnp
32+
import dpnp.memory as dpm
3233

3334

3435
def _get_unwrapped_index_key(key):
@@ -76,9 +77,12 @@ def __init__(
7677
order = "C"
7778

7879
if buffer is not None:
79-
buffer = dpnp.get_usm_ndarray(buffer)
80+
# expecting to have buffer as dpnp.ndarray and usm_ndarray,
81+
# or as USM memory allocation
82+
if isinstance(buffer, dpnp_array):
83+
buffer = buffer.get_array()
8084

81-
if dtype is None:
85+
if dtype is None and hasattr(buffer, "dtype"):
8286
dtype = buffer.dtype
8387
else:
8488
buffer = usm_type
@@ -1015,7 +1019,15 @@ def cumsum(self, axis=None, dtype=None, out=None):
10151019

10161020
return dpnp.cumsum(self, axis=axis, dtype=dtype, out=out)
10171021

1018-
# 'data',
1022+
@property
1023+
def data(self):
1024+
"""
1025+
Python object pointing to the start of USM memory allocation with the
1026+
array's data.
1027+
1028+
"""
1029+
1030+
return dpm.create_data(self._array_obj)
10191031

10201032
def diagonal(self, offset=0, axis1=0, axis2=1):
10211033
"""

dpnp/dpnp_iface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from dpnp.dpnp_array import dpnp_array
5454
from dpnp.fft import *
5555
from dpnp.linalg import *
56+
from dpnp.memory import *
5657
from dpnp.random import *
5758

5859
__all__ = [

dpnp/memory/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# -*- coding: utf-8 -*-
2+
# *****************************************************************************
3+
# Copyright (c) 2025, Intel Corporation
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
# - Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# - Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
24+
# THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
from ._memory import (
28+
MemoryUSMDevice,
29+
MemoryUSMHost,
30+
MemoryUSMShared,
31+
create_data,
32+
)
33+
34+
__all__ = [
35+
"MemoryUSMDevice",
36+
"MemoryUSMHost",
37+
"MemoryUSMShared",
38+
"create_data",
39+
]

dpnp/memory/_memory.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# -*- coding: utf-8 -*-
2+
# *****************************************************************************
3+
# Copyright (c) 2025, Intel Corporation
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
# - Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# - Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
24+
# THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
import dpctl.tensor as dpt
28+
from dpctl.memory import MemoryUSMDevice as DPCTLMemoryUSMDevice
29+
from dpctl.memory import MemoryUSMHost as DPCTLMemoryUSMHost
30+
from dpctl.memory import MemoryUSMShared as DPCTLMemoryUSMShared
31+
32+
33+
def _add_ptr_property(cls):
34+
_storage_attr = "_ptr"
35+
36+
@property
37+
def ptr(self):
38+
"""
39+
Returns USM pointer to the start of array (element with zero
40+
multi-index) encoded as integer.
41+
42+
"""
43+
44+
return getattr(self, _storage_attr, None)
45+
46+
@ptr.setter
47+
def ptr(self, value):
48+
setattr(self, _storage_attr, value)
49+
50+
cls.ptr = ptr
51+
return cls
52+
53+
54+
@_add_ptr_property
55+
class MemoryUSMDevice(DPCTLMemoryUSMDevice):
56+
pass
57+
58+
59+
@_add_ptr_property
60+
class MemoryUSMHost(DPCTLMemoryUSMHost):
61+
pass
62+
63+
64+
@_add_ptr_property
65+
class MemoryUSMShared(DPCTLMemoryUSMShared):
66+
pass
67+
68+
69+
def create_data(x):
70+
"""
71+
Create an instance of :class:`.MemoryUSMDevice`, :class:`.MemoryUSMHost`,
72+
or :class:`.MemoryUSMShared` class depending on the type of USM allocation.
73+
74+
Parameters
75+
----------
76+
x : usm_ndarray
77+
Input array of :class:`dpctl.tensor.usm_ndarray` type.
78+
79+
Returns
80+
-------
81+
out : {MemoryUSMDevice, MemoryUSMHost, MemoryUSMShared}
82+
A data object with a reference on USM memory.
83+
84+
"""
85+
86+
dispatch = {
87+
DPCTLMemoryUSMDevice: MemoryUSMDevice,
88+
DPCTLMemoryUSMHost: MemoryUSMHost,
89+
DPCTLMemoryUSMShared: MemoryUSMShared,
90+
}
91+
92+
if not isinstance(x, dpt.usm_ndarray):
93+
raise TypeError(
94+
f"An array must be any of supported type, but got {type(x)}"
95+
)
96+
usm_data = x.usm_data
97+
98+
cls = dispatch.get(type(usm_data), None)
99+
if cls:
100+
data = cls(usm_data)
101+
# `ptr` is expecting to point at the start of the array's data,
102+
# while `usm_data._pointer` is a pointer at the start of memory buffer
103+
data.ptr = x._pointer
104+
return data
105+
raise TypeError(f"Expected USM memory, but got {type(usm_data)}")

dpnp/tests/test_dlpack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_device(self):
7878
x = dpnp.arange(5)
7979
y = dpnp.from_dlpack(x, device=x.__dlpack_device__())
8080
assert x.device == y.device
81-
assert x.get_array()._pointer == y.get_array()._pointer
81+
assert x.data.ptr == y.data.ptr
8282

8383
def test_numpy_input(self):
8484
x = numpy.arange(10)

dpnp/tests/test_memory.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import dpctl.tensor as dpt
2+
import numpy
3+
import pytest
4+
5+
import dpnp
6+
import dpnp.memory as dpm
7+
8+
9+
class IntUsmData(dpt.usm_ndarray):
10+
"""Class that overrides `usm_data` property in `dpt.usm_ndarray`."""
11+
12+
@property
13+
def usm_data(self):
14+
return 1
15+
16+
17+
class TestCreateData:
18+
@pytest.mark.parametrize("x", [numpy.ones(4), dpnp.zeros(2)])
19+
def test_wrong_input_type(self, x):
20+
with pytest.raises(TypeError):
21+
dpm.create_data(x)
22+
23+
def test_wrong_usm_data(self):
24+
a = dpt.ones(10)
25+
d = IntUsmData(a.shape, buffer=a)
26+
27+
with pytest.raises(TypeError):
28+
dpm.create_data(d)

dpnp/tests/third_party/cupy/core_tests/test_dlpack.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ def __dlpack_device__(self):
4141
@pytest.mark.skip("toDlpack() and fromDlpack() are not supported")
4242
class TestDLPackConversion:
4343

44-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
4544
@testing.for_all_dtypes(no_bool=False)
46-
def test_conversion(self, dtype):
45+
def test_conversion(self, dtype, recwarn):
4746
orig_array = _gen_array(dtype)
4847
tensor = orig_array.toDlpack()
4948
out_array = cupy.fromDlpack(tensor)
5049
testing.assert_array_equal(orig_array, out_array)
51-
assert orig_array.get_array()._pointer == out_array.get_array()._pointer
50+
testing.assert_array_equal(orig_array.data.ptr, out_array.data.ptr)
51+
for w in recwarn:
52+
assert issubclass(w.category, cupy.VisibleDeprecationWarning)
5253

5354

5455
class TestNewDLPackConversion:
@@ -82,7 +83,7 @@ def test_conversion(self, dtype):
8283
orig_array = _gen_array(dtype)
8384
out_array = cupy.from_dlpack(orig_array)
8485
testing.assert_array_equal(orig_array, out_array)
85-
assert orig_array.get_array()._pointer == out_array.get_array()._pointer
86+
testing.assert_array_equal(orig_array.data.ptr, out_array.data.ptr)
8687

8788
@pytest.mark.skip("no limitations in from_dlpack()")
8889
def test_from_dlpack_and_conv_errors(self):
@@ -121,7 +122,7 @@ def test_conversion_max_version(self, kwargs, versioned):
121122
)
122123

123124
testing.assert_array_equal(orig_array, out_array)
124-
assert orig_array.get_array()._pointer == out_array.get_array()._pointer
125+
testing.assert_array_equal(orig_array.data.ptr, out_array.data.ptr)
125126

126127
def test_conversion_device(self):
127128
orig_array = _gen_array("float32")
@@ -135,7 +136,7 @@ def test_conversion_device(self):
135136
)
136137

137138
testing.assert_array_equal(orig_array, out_array)
138-
assert orig_array.get_array()._pointer == out_array.get_array()._pointer
139+
testing.assert_array_equal(orig_array.data.ptr, out_array.data.ptr)
139140

140141
def test_conversion_bad_device(self):
141142
arr = _gen_array("float32")
@@ -212,9 +213,8 @@ def test_stream(self):
212213
out_array = dlp.from_dlpack_capsule(dltensor)
213214
out_array = cupy.from_dlpack(out_array, device=dst_s)
214215
testing.assert_array_equal(orig_array, out_array)
215-
assert (
216-
orig_array.get_array()._pointer
217-
== out_array.get_array()._pointer
216+
testing.assert_array_equal(
217+
orig_array.data.ptr, out_array.data.ptr
218218
)
219219

220220

@@ -267,12 +267,13 @@ def test_deleter2(self, pool, max_version):
267267
# assert pool.n_free_blocks() == 1
268268

269269
@pytest.mark.skip("toDlpack() and fromDlpack() are not supported")
270-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
271-
def test_multiple_consumption_error(self):
270+
def test_multiple_consumption_error(self, recwarn):
272271
# Prevent segfault, see #3611
273272
array = cupy.empty(10)
274273
tensor = array.toDlpack()
275274
array2 = cupy.fromDlpack(tensor)
276275
with pytest.raises(ValueError) as e:
277276
array3 = cupy.fromDlpack(tensor)
278277
assert "consumed multiple times" in str(e.value)
278+
for w in recwarn:
279+
assert issubclass(w.category, cupy.VisibleDeprecationWarning)

dpnp/tests/third_party/cupy/core_tests/test_ndarray.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ def test_shape_not_integer(self):
4343

4444
def test_shape_int_with_strides(self):
4545
dummy = cupy.ndarray(3)
46-
a = cupy.ndarray(3, strides=(0,), buffer=dummy)
46+
a = cupy.ndarray(3, strides=(0,), buffer=dummy.data)
4747
assert a.shape == (3,)
4848
assert a.strides == (0,)
4949

5050
def test_memptr(self):
5151
a = cupy.arange(6).astype(numpy.float32).reshape((2, 3))
52-
memptr = a
52+
memptr = a.data
5353

5454
b = cupy.ndarray((2, 3), numpy.float32, memptr)
5555
testing.assert_array_equal(a, b)
@@ -62,7 +62,7 @@ def test_memptr(self):
6262
)
6363
def test_memptr_with_strides(self):
6464
buf = cupy.ndarray(20, numpy.uint8)
65-
memptr = buf
65+
memptr = buf.data
6666

6767
# self-overlapping strides
6868
a = cupy.ndarray((2, 3), numpy.float32, memptr, strides=(2, 1))
@@ -82,7 +82,9 @@ def test_strides_without_memptr(self):
8282

8383
def test_strides_is_given_and_order_is_ignored(self):
8484
buf = cupy.ndarray(20, numpy.uint8)
85-
a = cupy.ndarray((2, 3), numpy.float32, buf, strides=(2, 1), order="C")
85+
a = cupy.ndarray(
86+
(2, 3), numpy.float32, buf.data, strides=(2, 1), order="C"
87+
)
8688
assert a.strides == (2, 1)
8789

8890
@testing.with_requires("numpy>=1.19")

0 commit comments

Comments
 (0)