
Commit c1c4dce

shoyer authored and Xarray-Beam authors committed
Use RangeSource inside a new ReadDatasets ptransform
This is equivalent to the existing xbeam.DatasetToChunks, except it uses the beam source API instead of writing custom PTransforms. Hopefully this will be a bit more efficient! I have not yet exposed this as a public API because I want to test it more, and have the freedom to adjust the API. Probably I'll hook it up to xbeam.Dataset.from_xarray() initially.

PiperOrigin-RevId: 828074171
1 parent fc1d602 commit c1c4dce
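
For orientation, here is a minimal sketch of how the new transform might sit next to the existing one (the chunks are produced but not consumed). ReadDataset is intentionally not public API in this commit, so the xarray_beam._src.core import path, the example dataset, and the chunk sizes below are assumptions rather than documented usage:

import apache_beam as beam
import numpy as np
import xarray

import xarray_beam as xbeam
from xarray_beam._src import core  # ReadDataset is internal in this commit

# Hypothetical dataset: one variable on a 10x10 grid, read in 5x2 chunks.
ds = xarray.Dataset({"foo": (("x", "y"), np.zeros((10, 10)))})

with beam.Pipeline() as p:
  # Existing path: DatasetToChunks generates keys itself (optionally sharded)
  # and loads each chunk in a ThreadMap step.
  _ = p | "old" >> xbeam.DatasetToChunks(ds, chunks={"x": 5, "y": 2})
  # New path: the same (Key, xarray.Dataset) chunks come from a RangeSource,
  # which Beam runners can split dynamically.
  _ = p | "new" >> core.ReadDataset(ds, chunks={"x": 5, "y": 2})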

3 files changed: +259 -87 lines changed


xarray_beam/_src/core.py

Lines changed: 177 additions & 73 deletions
@@ -16,7 +16,7 @@

 from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
 import contextlib
-from functools import cached_property
+import functools
 import itertools
 import math
 import pickle
@@ -27,14 +27,15 @@
 import immutabledict
 import numpy as np
 import xarray
+from xarray_beam._src import range_source
 from xarray_beam._src import threadmap


-T = TypeVar('T')
+T = TypeVar("T")


 def export(obj: T) -> T:
-  obj.__module__ = 'xarray_beam'
+  obj.__module__ = "xarray_beam"
   return obj


@@ -122,7 +123,6 @@ class Key:
   Key(indices={'x': 4}, vars={'bar'})
   >>> key.with_indices(x=5)
   Key(indices={'x': 5}, vars={'bar'})
-
   """

   # pylint: disable=redefined-builtin
@@ -184,8 +184,8 @@ def with_indices(self, **indices: int | None) -> Key:
     """Replace some indices with new values.

     Args:
-      **indices: indices to override (for integer values) or remove, with
-        values of ``None``.
+      **indices: indices to override (for integer values) or remove, with values
+        of ``None``.

     Returns:
       New Key with the specified indices.
@@ -421,49 +421,19 @@ def normalize_expanded_chunks(
   )


-@export
-class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
-  """Split one or more xarray.Datasets into keyed chunks."""
+class _DatasetToChunksBase(beam.PTransform, Generic[DatasetOrDatasets]):
+  """Base class for PTransforms that split Datasets into chunks."""

   def __init__(
       self,
       dataset: DatasetOrDatasets,
       chunks: Mapping[str, int | tuple[int, ...]] | None = None,
       split_vars: bool = False,
-      num_threads: int | None = None,
-      shard_keys_threshold: int = 200_000,
-      tasks_per_shard: int = 10_000,
   ):
-    """Initialize DatasetToChunks.
-
-    Args:
-      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
-        [xarray.Dataset, ...]) pairs.
-      chunks: optional chunking scheme. Required if the dataset is *not* already
-        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
-        precedence over the existing chunks.
-      split_vars: whether to split the dataset into separate records for each
-        data variable or to keep all data variables together. This is
-        recommended if you don't need to perform joint operations on different
-        dataset variables and individual variable chunks are sufficiently large.
-      num_threads: optional number of Dataset chunks to load in parallel per
-        worker. More threads can increase throughput, but also increases memory
-        usage and makes it harder for Beam runners to shard work. Note that each
-        variable in a Dataset is already loaded in parallel, so this is most
-        useful for Datasets with a small number of variables or when using
-        split_vars=True.
-      shard_keys_threshold: threshold at which to compute keys on Beam workers,
-        rather than only on the host process. This is important for scaling
-        pipelines to millions of tasks.
-      tasks_per_shard: number of tasks to emit per shard. Only used if the
-        number of tasks exceeds shard_keys_threshold.
-    """
+    """Initialize _DatasetToChunksBase."""
     self.dataset = dataset
     self._validate(dataset, split_vars)
     self.split_vars = split_vars
-    self.num_threads = num_threads
-    self.shard_keys_threshold = shard_keys_threshold
-    self.tasks_per_shard = tasks_per_shard

     if chunks is None:
       dask_chunks = self._first.chunks
489459
return [self.dataset]
490460
return list(self.dataset) # pytype: disable=bad-return-type
491461

492-
@cached_property
462+
@functools.cached_property
493463
def expanded_chunks(self) -> dict[str, tuple[int, ...]]:
494464
return normalize_expanded_chunks(self.chunks, self._first.sizes) # pytype: disable=wrong-arg-types # always-use-property-annotation
495465

496-
@cached_property
466+
@functools.cached_property
497467
def offsets(self) -> dict[str, list[int]]:
498468
return _chunks_to_offsets(self.expanded_chunks)
499469

500-
@cached_property
470+
@functools.cached_property
501471
def offset_index(self) -> dict[str, dict[int, int]]:
502472
return compute_offset_index(self.offsets)
503473

@@ -542,7 +512,78 @@ def _task_count(self) -> int:
542512
total += int(np.prod(count_list))
543513
return total
544514

545-
@cached_property
515+
def _key_to_chunks(self, key: Key) -> tuple[Key, DatasetOrDatasets]:
516+
"""Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
517+
with inc_timer_msec(self.__class__, "read-msec"):
518+
sizes = {
519+
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
520+
for dim, offset in key.offsets.items()
521+
}
522+
slices = offsets_to_slices(key.offsets, sizes)
523+
results = []
524+
for ds in self._datasets:
525+
dataset = ds if key.vars is None else ds[list(key.vars)]
526+
valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
527+
chunk = dataset.isel(valid_slices)
528+
# Load the data, using a separate thread for each variable
529+
num_threads = len(dataset)
530+
result = chunk.chunk().compute(num_workers=num_threads)
531+
results.append(result)
532+
533+
inc_counter(self.__class__, "read-chunks")
534+
inc_counter(
535+
self.__class__, "read-bytes", sum(result.nbytes for result in results)
536+
)
537+
538+
if isinstance(self.dataset, xarray.Dataset):
539+
return key, results[0]
540+
else:
541+
return key, results
542+
543+
544+
@export
545+
class DatasetToChunks(_DatasetToChunksBase):
546+
"""Split one or more xarray.Datasets into keyed chunks."""
547+
548+
def __init__(
549+
self,
550+
dataset: DatasetOrDatasets,
551+
chunks: Mapping[str, int | tuple[int, ...]] | None = None,
552+
split_vars: bool = False,
553+
num_threads: int | None = None,
554+
shard_keys_threshold: int = 200_000,
555+
tasks_per_shard: int = 10_000,
556+
):
557+
"""Initialize DatasetToChunks.
558+
559+
Args:
560+
dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
561+
[xarray.Dataset, ...]) pairs.
562+
chunks: optional chunking scheme. Required if the dataset is *not* already
563+
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
564+
precedence over the existing chunks.
565+
split_vars: whether to split the dataset into separate records for each
566+
data variable or to keep all data variables together. This is
567+
recommended if you don't need to perform joint operations on different
568+
dataset variables and individual variable chunks are sufficiently large.
569+
num_threads: optional number of Dataset chunks to load in parallel per
570+
worker. More threads can increase throughput, but also increases memory
571+
usage and makes it harder for Beam runners to shard work. Note that each
572+
variable in a Dataset is already loaded in parallel, so this is most
573+
useful for Datasets with a small number of variables or when using
574+
split_vars=True.
575+
shard_keys_threshold: threshold at which to compute keys on Beam workers,
576+
rather than only on the host process. This is important for scaling
577+
pipelines to millions of tasks.
578+
tasks_per_shard: number of tasks to emit per shard. Only used if the
579+
number of tasks exceeds shard_keys_threshold.
580+
"""
581+
super().__init__(dataset, chunks, split_vars)
582+
self.num_threads = num_threads
583+
self.shard_keys_threshold = shard_keys_threshold
584+
self.tasks_per_shard = tasks_per_shard
585+
586+
@functools.cached_property
546587
def sharded_dim(self) -> str | None:
547588
# We use the simple heuristic of only sharding inputs along the dimension
548589
# with the most chunks.
@@ -552,7 +593,7 @@ def sharded_dim(self) -> str | None:
     }
     return max(lengths, key=lengths.get) if lengths else None  # pytype: disable=bad-return-type

-  @cached_property
+  @functools.cached_property
   def shard_count(self) -> int | None:
     """Determine the number of times to shard input keys."""
     task_count = self._task_count()
@@ -610,34 +651,6 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
         inputs.append((None, name))
     return inputs  # pytype: disable=bad-return-type  # always-use-property-annotation

-  def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
-    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
-    with inc_timer_msec(self.__class__, "read-msec"):
-      sizes = {
-          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
-          for dim, offset in key.offsets.items()
-      }
-      slices = offsets_to_slices(key.offsets, sizes)
-      results = []
-      for ds in self._datasets:
-        dataset = ds if key.vars is None else ds[list(key.vars)]
-        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
-        chunk = dataset.isel(valid_slices)
-        # Load the data, using a separate thread for each variable
-        num_threads = len(dataset)
-        result = chunk.chunk().compute(num_workers=num_threads)
-        results.append(result)
-
-      inc_counter(self.__class__, "read-chunks")
-      inc_counter(
-          self.__class__, "read-bytes", sum(result.nbytes for result in results)
-      )
-
-    if isinstance(self.dataset, xarray.Dataset):
-      yield key, results[0]
-    else:
-      yield key, results
-
   def expand(self, pcoll):
     if self.shard_count is None:
       # Create all keys on the machine launching the Beam pipeline. This is
@@ -652,11 +665,102 @@ def expand(self, pcoll):
           | beam.Reshuffle()
       )

-    return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap(
+    return key_pcoll | "KeyToChunks" >> threadmap.ThreadMap(
         self._key_to_chunks, num_threads=self.num_threads
     )


+# TODO(shoyer): expose this function as a public API, after switching it to
+# generate Key objects using `indices` instead of `offsets`.
+class ReadDataset(_DatasetToChunksBase):
+  """Read chunks from an xarray.Dataset into a Beam pipeline.
+
+  This PTransform is a Beam "splittable DoFn", which means that it may be
+  dynamically split by Beam runners into smaller chunks for efficient parallel
+  execution.
+  """
+
+  def __init__(
+      self,
+      dataset: xarray.Dataset,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+  ):
+    """Initialize ReadDatasets.
+
+    Args:
+      dataset: dataset to split into (Key, xarray.Dataset) chunks.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+    """
+    super().__init__(dataset, chunks, split_vars)
+
+  @functools.cached_property
+  def _chunk_index_shapes(
+      self,
+  ) -> list[tuple[str | None, tuple[str, ...], tuple[int, ...]]]:
+    """Calculate the shapes of indices for each chunk of the data.
+
+    The result here is a list of tuples of the form (name, dims, shape), where
+    name is the name of the variable (or None if all variables are consolidated)
+    and dims and shape are the dimensions along which the variable's chunk is
+    indexed, and shape of that chunk in _indices_. For example, if the dataset
+    had a variable `foo` with dimensions `('x', 'y')`, shape (10, 10) with
+    chunks `{'x': 5, 'y': 2}`, then this function would return a corresponding
+    list entry `('foo', ('x', 'y'), (2, 5))`.
+    """
+    out = []
+    if not self.split_vars:
+      dims = sorted(self.expanded_chunks)
+      shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+      out.append((None, dims, shape))
+    else:
+      for name, variable in self._first.items():
+        dims = tuple(d for d in variable.dims if d in self.expanded_chunks)
+        shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+        out.append((name, dims, shape))
+    return out  # pytype: disable=bad-return-type
+
+  @functools.cached_property
+  def _cumulative_sizes(self) -> np.ndarray:
+    var_sizes = [math.prod(shape) for _, _, shape in self._chunk_index_shapes]
+    return np.cumsum([0] + var_sizes)
+
+  def _index_to_key(self, position: int) -> Key:
+    assert 0 <= position < self._cumulative_sizes[-1]
+    var_index = (
+        np.searchsorted(self._cumulative_sizes, position, side="right") - 1
+    )
+    offset = position - self._cumulative_sizes[var_index]
+    name, dims, shape = self._chunk_index_shapes[var_index]
+    indices = np.unravel_index(offset, shape)
+    offsets = {dim: self.offsets[dim][idx] for dim, idx in zip(dims, indices)}
+    return Key(offsets, vars=None if name is None else {name})
+
+  def _get_element(self, position: int) -> tuple[Key, xarray.Dataset]:
+    return self._key_to_chunks(self._index_to_key(position))  # pytype: disable=bad-return-type
+
+  def expand(
+      self, pbegin: beam.PBegin
+  ) -> beam.PCollection[tuple[Key, xarray.Dataset]]:
+    element_count = self._task_count()
+    assert element_count > 0
+    # For simplicity, assume that all chunks are approximately the same size,
+    # even if variables are being split and some variables have different
+    # variables. This assumption could be relaxed in the future, with an
+    # improved version of RangeSource.
+    avg_chunk_bytes = math.ceil(self._first.nbytes / element_count)
+    source = range_source.RangeSource(
+        element_count, avg_chunk_bytes, self._get_element
+    )
+    return pbegin | beam.io.Read(source)
+
+
 def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
   """Ensure that a dataset contains no chunked variables."""
   for var_name, variable in dataset.variables.items():

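To make the new indexing concrete, here is a standalone sketch of the position-to-Key arithmetic that _chunk_index_shapes, _cumulative_sizes, and _index_to_key implement above, using the 10x10 foo example from the _chunk_index_shapes docstring (chunks {'x': 5, 'y': 2}, so a (2, 5) grid of chunk indices). The helper position_to_offsets is illustrative only and not part of the library:

import math

import numpy as np

# Chunk sizes per dimension for foo: 2 chunks along x, 5 along y.
expanded_chunks = {"x": (5, 5), "y": (2, 2, 2, 2, 2)}
# Starting offsets of each chunk along each dimension.
offsets = {
    dim: np.cumsum((0,) + sizes[:-1]).tolist()
    for dim, sizes in expanded_chunks.items()
}

dims = tuple(sorted(expanded_chunks))                  # ('x', 'y')
shape = tuple(len(expanded_chunks[d]) for d in dims)   # (2, 5) chunk-index grid
cumulative = np.cumsum([0, math.prod(shape)])          # one variable -> [0, 10]


def position_to_offsets(position: int) -> dict[str, int]:
  """Map a flat RangeSource position to per-dimension chunk offsets."""
  assert 0 <= position < cumulative[-1]
  indices = np.unravel_index(position, shape)
  return {dim: offsets[dim][int(idx)] for dim, idx in zip(dims, indices)}


# Position 7 unravels to chunk indices (1, 2), i.e. offsets x=5, y=4, which
# ReadDataset would wrap as Key({'x': 5, 'y': 4}) before loading the chunk.
assert position_to_offsets(7) == {"x": 5, "y": 4}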