"""Micro-benchmark: write a sharded Zarr array one full shard at a time."""
import itertools
import os.path
import shutil
import sys
import tempfile
import timeit

import numpy as np

import zarr
import zarr.codecs
import zarr.codecs.sharding

if __name__ == "__main__":
    # NOTE(review): this runs *after* `import zarr`, so it cannot redirect the
    # import to a development checkout; move it above the imports if that is
    # the intent.
    sys.path.insert(0, "..")

    # setup
    with tempfile.TemporaryDirectory() as path:
        ndim = 3
        opt = {
            "shape": [1024] * ndim,
            "chunks": [128] * ndim,
            "shards": [512] * ndim,
            "dtype": np.float64,
        }

        store = zarr.storage.LocalStore(path)
        z = zarr.create_array(store, **opt)
        print(z)

        # Seeded generator so every run (and every timeit repeat) writes
        # identical data.
        rng = np.random.default_rng(0)

        def cleanup() -> None:
            """Delete all written chunk data, keeping the .json metadata."""
            for name in os.listdir(path):
                entry = os.path.join(path, name)
                if entry.endswith(".json"):
                    continue
                if os.path.isdir(entry):
                    shutil.rmtree(entry)
                else:
                    os.remove(entry)

        def write() -> None:
            """Write the whole array, one shard-aligned (512**ndim) block at a time."""
            wchunk = [512] * ndim
            nwchunks = [n // s for n, s in zip(opt["shape"], wchunk, strict=True)]
            for shard in itertools.product(*(range(n) for n in nwchunks)):
                slicer = tuple(
                    slice(i * n, (i + 1) * n)
                    for i, n in zip(shard, wchunk, strict=True)
                )
                z[slicer] = rng.random(wchunk, dtype=opt["dtype"])

        print("*" * 79)

        # time: `cleanup` runs as setup before each repeat so every
        # measurement starts from an empty store.  (`vars` renamed: it
        # shadowed the builtin.)
        timeit_globals = {"write": write, "cleanup": cleanup}
        t = timeit.repeat(
            "write()", "cleanup()", repeat=2, number=1, globals=timeit_globals
        )
        print(t)
        print(min(t))
        print(z)

        # profile (uncomment; line_profiler is only needed here, so the
        # unconditional top-level import was dropped)
        # import line_profiler
        # f = zarr.codecs.sharding.ShardingCodec._encode_partial_single
        # profile = line_profiler.LineProfiler(f)
        # profile.run("write()")
        # profile.print_stats()
class DelayedBuffer(Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.

    The pieces are kept as 1-D byte arrays in ``self._data_list`` and are only
    concatenated into one contiguous array when ``self._data`` is read, so
    ``__add__`` and slicing stay cheap.  Concrete subclasses (cpu/gpu) supply
    ``_BufferImpl`` and ``_concatenate``.

    NOTE(review): slices returned by ``__getitem__`` are views into the
    original pieces — mutating them mutates this buffer; confirm callers
    treat buffers as immutable.
    """

    # Concrete Buffer implementation for the pieces; set by subclasses.
    _BufferImpl: type
    # Function concatenating a list of 1-D byte arrays; set by subclasses.
    _concatenate: callable

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        """Store *array*: a single piece, a list of pieces, or None for empty."""
        if array is None:
            self._data_list = []
        elif isinstance(array, list):
            self._data_list = list(array)
        else:
            self._data_list = [array]
        for piece in self._data_list:
            if piece.ndim != 1:
                raise ValueError("array: only 1-dim allowed")
            if piece.dtype != np.dtype("b"):
                raise ValueError("array: only byte dtype allowed")

    @property
    def _data(self) -> npt.NDArray[Any]:
        """Materialize all pieces into one contiguous array (copies)."""
        if not self._data_list:
            # BUG FIX: np.concatenate/cp.concatenate raise on an empty list.
            # Return an empty byte array instead (host array even for the GPU
            # subclass — acceptable for a zero-length buffer).
            return np.empty(0, dtype="b")
        return type(self)._concatenate(self._data_list)

    @classmethod
    def from_buffer(cls, buffer: Buffer) -> Self:
        if isinstance(buffer, cls):
            # Share the piece list; no materialization needed.
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    def __add__(self, other: Buffer) -> Self:
        """Concatenation is O(1): extend the piece list, copy nothing."""
        if isinstance(other, self.__class__):
            return self.__class__(self._data_list + other._data_list)
        else:
            return self.__class__(self._data_list + [other._data])

    def __radd__(self, other: Buffer) -> Self:
        if isinstance(other, self.__class__):
            return self.__class__(other._data_list + self._data_list)
        else:
            return self.__class__([other._data] + self._data_list)

    def __len__(self) -> int:
        return sum(map(len, self._data_list))

    def _normalize_range(self, key: slice) -> tuple[int, int]:
        """Validate *key* is 1-D contiguous; return (start, stop) clamped to this buffer."""
        check_item_key_is_1d_contiguous(key)
        length = len(self)
        start = 0 if key.start is None else key.start
        if start < 0:
            start += length
        stop = length if key.stop is None else key.stop
        if stop < 0:
            stop += length
        return start, min(stop, length)

    def __getitem__(self, key: slice) -> Self:
        start, stop = self._normalize_range(key)
        if stop <= start:
            # BUG FIX: the original returned Buffer.from_buffer(b''), but
            # from_buffer expects a Buffer instance, not raw bytes (it reads
            # buffer._data).  Return an empty delayed buffer instead.
            return self.__class__([])

        new_list = []
        offset = 0
        for chunk in self._data_list:
            chunk_size = len(chunk)
            lo = start - offset  # requested start, relative to this piece
            hi = stop - offset   # requested stop, relative to this piece
            offset += chunk_size
            if lo >= chunk_size:
                # piece lies entirely before the requested range
                continue
            if lo >= 0:
                # piece containing `start` (possibly also `stop`)
                chunk = chunk[lo:hi] if hi <= chunk_size else chunk[lo:]
            elif hi <= chunk_size:
                # piece containing `stop`
                chunk = chunk[:hi]
            # else: interior piece, kept whole
            new_list.append(chunk)
            if hi <= chunk_size:
                # `stop` fell inside (or at the end of) this piece: done
                break
        assert sum(map(len, new_list)) == stop - start
        return self.__class__(new_list)

    def __setitem__(self, key: slice, value: Any) -> None:
        """Write *value* (already broadcast to stop - start bytes) in place."""
        start, stop = self._normalize_range(key)
        if stop <= start:
            return

        value = memoryview(np.asanyarray(value))
        offset = 0
        for chunk in self._data_list:
            chunk_size = len(chunk)
            lo = start - offset
            hi = stop - offset
            offset += chunk_size
            if lo >= chunk_size:
                # piece lies entirely before the requested range
                continue
            if lo >= 0:
                chunk = chunk[lo:hi] if hi <= chunk_size else chunk[lo:]
            elif hi <= chunk_size:
                chunk = chunk[:hi]
            # Write through the view, consuming the front of `value`.
            chunk[:] = value[: len(chunk)]
            value = value[len(chunk):]
            if hi <= chunk_size or len(value) == 0:
                # reached the piece containing `stop`, or nothing left to write
                break
class DelayedBuffer(core.DelayedBuffer, Buffer):
    """
    A host-memory Buffer that is the virtual concatenation of other buffers.

    Pieces are normalized to 1-D numpy byte arrays on construction and are
    concatenated lazily, only when the contiguous ``_data`` is needed.
    """

    _BufferImpl = Buffer
    # staticmethod so class-attribute lookup can never descriptor-bind the
    # stored callable.
    _concatenate = staticmethod(np.concatenate)

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        core.DelayedBuffer.__init__(self, array)
        # Normalize every piece to a numpy array (no copy when already numpy).
        self._data_list = list(map(np.asanyarray, self._data_list))

    @classmethod
    def create_zero_length(cls) -> Self:
        """Return an empty delayed buffer."""
        return cls(np.array([], dtype="b"))

    @classmethod
    def from_buffer(cls, buffer: core.Buffer) -> Self:
        if isinstance(buffer, cls):
            # Reuse the existing piece list without materializing it.
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    @classmethod
    def from_bytes(cls, bytes_like: BytesLike) -> Self:
        # BUG FIX: np.asarray(b"...", dtype="b") raises ValueError for bytes
        # input; np.frombuffer is the correct (zero-copy) bytes -> int8 view,
        # matching cpu.Buffer.from_bytes.
        return cls(np.frombuffer(bytes_like, dtype="b"))

    def as_numpy_array(self) -> npt.NDArray[Any]:
        return np.asanyarray(self._data)


# Expose the delayed variant on the prototype's buffer class so sharding's
# create_empty can reach it via buffer_prototype.buffer.Delayed.
Buffer.Delayed = DelayedBuffer
class DelayedBuffer(core.DelayedBuffer, Buffer):
    """
    A device-memory Buffer that is the virtual concatenation of other buffers.

    Pieces are moved to the GPU (cupy arrays) on construction and are
    concatenated lazily, only when the contiguous ``_data`` is needed.
    """

    _BufferImpl = Buffer
    # cupy may be unavailable (cp falsy/None); staticmethod also prevents
    # descriptor binding of the stored callable.
    _concatenate = staticmethod(getattr(cp, "concatenate", None))

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        core.DelayedBuffer.__init__(self, array)
        # Move every piece to the device (no copy when already a cupy array).
        self._data_list = list(map(cp.asarray, self._data_list))

    @classmethod
    def create_zero_length(cls) -> Self:
        """Return an empty delayed buffer (pieces end up device-resident)."""
        return cls(np.array([], dtype="b"))

    @classmethod
    def from_buffer(cls, buffer: core.Buffer) -> Self:
        if isinstance(buffer, cls):
            # Reuse the existing piece list without materializing it.
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    @classmethod
    def from_bytes(cls, bytes_like: BytesLike) -> Self:
        # BUG FIX: np.asarray(b"...", dtype="b") raises ValueError for bytes
        # input; view the host bytes with np.frombuffer — __init__ then moves
        # the piece to the device.
        return cls(np.frombuffer(bytes_like, dtype="b"))

    def as_numpy_array(self) -> npt.NDArray[Any]:
        # BUG FIX: np.asanyarray cannot implicitly convert a cupy array
        # (cupy forbids implicit device-to-host conversion); cp.asnumpy does
        # the explicit copy back to host memory.
        return cp.asnumpy(self._data)


# Expose the delayed variant on the prototype's buffer class so sharding's
# create_empty can reach it via buffer_prototype.buffer.Delayed.
Buffer.Delayed = DelayedBuffer


buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)

register_buffer(Buffer)