@@ -16,7 +16,7 @@
 
 from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
 import contextlib
-from functools import cached_property
+import functools
 import itertools
 import math
 import pickle
@@ -27,14 +27,15 @@
 import immutabledict
 import numpy as np
 import xarray
+from xarray_beam._src import range_source
 from xarray_beam._src import threadmap
 
 
-T = TypeVar('T')
+T = TypeVar("T")
 
 
 def export(obj: T) -> T:
-  obj.__module__ = 'xarray_beam'
+  obj.__module__ = "xarray_beam"
   return obj
 
@@ -122,7 +123,6 @@ class Key:
   Key(indices={'x': 4}, vars={'bar'})
   >>> key.with_indices(x=5)
   Key(indices={'x': 5}, vars={'bar'})
-
   """
 
   # pylint: disable=redefined-builtin
@@ -184,8 +184,8 @@ def with_indices(self, **indices: int | None) -> Key:
     """Replace some indices with new values.
 
     Args:
-      **indices: indices to override (for integer values) or remove, with
-        values of ``None``.
+      **indices: indices to override (for integer values) or remove, with values
+        of ``None``.
 
     Returns:
       New Key with the specified indices.
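
Per the Args description, a value of ``None`` removes an index instead of setting it. A hedged doctest-style sketch (values are hypothetical; the repr format follows the class docstring above):

    >>> key = Key(indices={'x': 4, 'y': 2}, vars={'bar'})
    >>> key.with_indices(y=None)
    Key(indices={'x': 4}, vars={'bar'})
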
@@ -421,49 +421,19 @@ def normalize_expanded_chunks(
 )
 
 
-@export
-class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
-  """Split one or more xarray.Datasets into keyed chunks."""
+class _DatasetToChunksBase(beam.PTransform, Generic[DatasetOrDatasets]):
+  """Base class for PTransforms that split Datasets into chunks."""
 
   def __init__(
       self,
       dataset: DatasetOrDatasets,
       chunks: Mapping[str, int | tuple[int, ...]] | None = None,
       split_vars: bool = False,
-      num_threads: int | None = None,
-      shard_keys_threshold: int = 200_000,
-      tasks_per_shard: int = 10_000,
   ):
-    """Initialize DatasetToChunks.
-
-    Args:
-      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
-        [xarray.Dataset, ...]) pairs.
-      chunks: optional chunking scheme. Required if the dataset is *not* already
-        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
-        precedence over the existing chunks.
-      split_vars: whether to split the dataset into separate records for each
-        data variable or to keep all data variables together. This is
-        recommended if you don't need to perform joint operations on different
-        dataset variables and individual variable chunks are sufficiently large.
-      num_threads: optional number of Dataset chunks to load in parallel per
-        worker. More threads can increase throughput, but also increases memory
-        usage and makes it harder for Beam runners to shard work. Note that each
-        variable in a Dataset is already loaded in parallel, so this is most
-        useful for Datasets with a small number of variables or when using
-        split_vars=True.
-      shard_keys_threshold: threshold at which to compute keys on Beam workers,
-        rather than only on the host process. This is important for scaling
-        pipelines to millions of tasks.
-      tasks_per_shard: number of tasks to emit per shard. Only used if the
-        number of tasks exceeds shard_keys_threshold.
-    """
+    """Initialize _DatasetToChunksBase."""
     self.dataset = dataset
     self._validate(dataset, split_vars)
     self.split_vars = split_vars
-    self.num_threads = num_threads
-    self.shard_keys_threshold = shard_keys_threshold
-    self.tasks_per_shard = tasks_per_shard
 
     if chunks is None:
       dask_chunks = self._first.chunks
@@ -489,15 +459,15 @@ def _datasets(self) -> list[xarray.Dataset]:
       return [self.dataset]
     return list(self.dataset)  # pytype: disable=bad-return-type
 
-  @cached_property
+  @functools.cached_property
   def expanded_chunks(self) -> dict[str, tuple[int, ...]]:
     return normalize_expanded_chunks(self.chunks, self._first.sizes)  # pytype: disable=wrong-arg-types # always-use-property-annotation
 
-  @cached_property
+  @functools.cached_property
   def offsets(self) -> dict[str, list[int]]:
     return _chunks_to_offsets(self.expanded_chunks)
 
-  @cached_property
+  @functools.cached_property
   def offset_index(self) -> dict[str, dict[int, int]]:
     return compute_offset_index(self.offsets)
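
How these three cached properties fit together, sketched as comments (this assumes `_chunks_to_offsets` produces chunk start positions, which matches how `offset_index` is used in `_key_to_chunks` below):

    # For a 10-element dimension 'x' split into chunks of 4:
    # expanded_chunks == {'x': (4, 4, 2)}
    # offsets == {'x': [0, 4, 8]}                 # start offset of each chunk
    # offset_index == {'x': {0: 0, 4: 1, 8: 2}}   # start offset -> chunk number
    # expanded_chunks['x'][offset_index['x'][8]] == 2, the final chunk's size
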
@@ -542,7 +512,78 @@ def _task_count(self) -> int:
       total += int(np.prod(count_list))
     return total
 
-  @cached_property
+  def _key_to_chunks(self, key: Key) -> tuple[Key, DatasetOrDatasets]:
+    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
+    with inc_timer_msec(self.__class__, "read-msec"):
+      sizes = {
+          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
+          for dim, offset in key.offsets.items()
+      }
+      slices = offsets_to_slices(key.offsets, sizes)
+      results = []
+      for ds in self._datasets:
+        dataset = ds if key.vars is None else ds[list(key.vars)]
+        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
+        chunk = dataset.isel(valid_slices)
+        # Load the data, using a separate thread for each variable
+        num_threads = len(dataset)
+        result = chunk.chunk().compute(num_workers=num_threads)
+        results.append(result)
+
+    inc_counter(self.__class__, "read-chunks")
+    inc_counter(
+        self.__class__, "read-bytes", sum(result.nbytes for result in results)
+    )
+
+    if isinstance(self.dataset, xarray.Dataset):
+      return key, results[0]
+    else:
+      return key, results
+
+
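The `offsets_to_slices` step above turns a key's start offsets and the looked-up chunk sizes into `isel`-compatible slices; a minimal sketch using the public helper (dimension name and numbers are hypothetical):

    import xarray_beam as xbeam

    # A chunk starting at offset 10 along 'x' with size 5 covers x=10..14:
    assert xbeam.offsets_to_slices({'x': 10}, {'x': 5}) == {'x': slice(10, 15)}
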
+@export
+class DatasetToChunks(_DatasetToChunksBase):
+  """Split one or more xarray.Datasets into keyed chunks."""
+
+  def __init__(
+      self,
+      dataset: DatasetOrDatasets,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+      num_threads: int | None = None,
+      shard_keys_threshold: int = 200_000,
+      tasks_per_shard: int = 10_000,
+  ):
+    """Initialize DatasetToChunks.
+
+    Args:
+      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
+        [xarray.Dataset, ...]) pairs.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+      num_threads: optional number of Dataset chunks to load in parallel per
+        worker. More threads can increase throughput, but also increases memory
+        usage and makes it harder for Beam runners to shard work. Note that each
+        variable in a Dataset is already loaded in parallel, so this is most
+        useful for Datasets with a small number of variables or when using
+        split_vars=True.
+      shard_keys_threshold: threshold at which to compute keys on Beam workers,
+        rather than only on the host process. This is important for scaling
+        pipelines to millions of tasks.
+      tasks_per_shard: number of tasks to emit per shard. Only used if the
+        number of tasks exceeds shard_keys_threshold.
+    """
+    super().__init__(dataset, chunks, split_vars)
+    self.num_threads = num_threads
+    self.shard_keys_threshold = shard_keys_threshold
+    self.tasks_per_shard = tasks_per_shard
+
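Typical use of the public transform, as a minimal sketch (the Zarr path and chunk size are placeholders):

    import apache_beam as beam
    import xarray
    import xarray_beam as xbeam

    ds = xarray.open_zarr('example.zarr')  # hypothetical input
    with beam.Pipeline() as p:
      # Emits one (Key, xarray.Dataset) pair per chunk of 100 time steps.
      _ = p | xbeam.DatasetToChunks(ds, chunks={'time': 100})
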
+  @functools.cached_property
   def sharded_dim(self) -> str | None:
     # We use the simple heuristic of only sharding inputs along the dimension
     # with the most chunks.
@@ -552,7 +593,7 @@ def sharded_dim(self) -> str | None:
     }
     return max(lengths, key=lengths.get) if lengths else None  # pytype: disable=bad-return-type
 
-  @cached_property
+  @functools.cached_property
   def shard_count(self) -> int | None:
     """Determine the number of times to shard input keys."""
     task_count = self._task_count()
@@ -610,34 +651,6 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
       inputs.append((None, name))
     return inputs  # pytype: disable=bad-return-type # always-use-property-annotation
 
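A hedged back-of-envelope for the defaults above (the exact formula lives in the elided body of `shard_count`, so the division here is an assumption): a dataset producing 1,000,000 tasks exceeds `shard_keys_threshold=200_000`, so keys are generated on Beam workers rather than on the launcher, in roughly 1,000,000 / 10,000 = 100 shards of `tasks_per_shard` keys each.
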
-  def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
-    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
-    with inc_timer_msec(self.__class__, "read-msec"):
-      sizes = {
-          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
-          for dim, offset in key.offsets.items()
-      }
-      slices = offsets_to_slices(key.offsets, sizes)
-      results = []
-      for ds in self._datasets:
-        dataset = ds if key.vars is None else ds[list(key.vars)]
-        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
-        chunk = dataset.isel(valid_slices)
-        # Load the data, using a separate thread for each variable
-        num_threads = len(dataset)
-        result = chunk.chunk().compute(num_workers=num_threads)
-        results.append(result)
-
-    inc_counter(self.__class__, "read-chunks")
-    inc_counter(
-        self.__class__, "read-bytes", sum(result.nbytes for result in results)
-    )
-
-    if isinstance(self.dataset, xarray.Dataset):
-      yield key, results[0]
-    else:
-      yield key, results
-
   def expand(self, pcoll):
     if self.shard_count is None:
       # Create all keys on the machine launching the Beam pipeline. This is
@@ -652,11 +665,102 @@ def expand(self, pcoll):
         | beam.Reshuffle()
     )
 
-    return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap(
+    return key_pcoll | "KeyToChunks" >> threadmap.ThreadMap(
        self._key_to_chunks, num_threads=self.num_threads
    )
 
 
+# TODO(shoyer): expose this class as a public API, after switching it to
+# generate Key objects using `indices` instead of `offsets`.
+class ReadDataset(_DatasetToChunksBase):
+  """Read chunks from an xarray.Dataset into a Beam pipeline.
+
+  This PTransform is a Beam "splittable DoFn", which means that it may be
+  dynamically split by Beam runners into smaller chunks for efficient parallel
+  execution.
+  """
+
+  def __init__(
+      self,
+      dataset: xarray.Dataset,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+  ):
+    """Initialize ReadDataset.
+
+    Args:
+      dataset: dataset to split into (Key, xarray.Dataset) chunks.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+    """
+    super().__init__(dataset, chunks, split_vars)
+
+  @functools.cached_property
+  def _chunk_index_shapes(
+      self,
+  ) -> list[tuple[str | None, tuple[str, ...], tuple[int, ...]]]:
+    """Calculate the shapes of indices for each chunk of the data.
+
+    The result is a list of tuples of the form (name, dims, shape), where
+    name is the name of the variable (or None if all variables are
+    consolidated), dims gives the dimensions along which the variable is
+    chunked, and shape gives the number of chunks along each of those
+    dimensions. For example, if the dataset had a variable `foo` with
+    dimensions `('x', 'y')` and shape (10, 10), chunked with
+    `{'x': 5, 'y': 2}`, then this function would return a corresponding list
+    entry `('foo', ('x', 'y'), (2, 5))`.
+    """
+    out = []
+    if not self.split_vars:
+      dims = sorted(self.expanded_chunks)
+      shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+      out.append((None, dims, shape))
+    else:
+      for name, variable in self._first.items():
+        dims = tuple(d for d in variable.dims if d in self.expanded_chunks)
+        shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+        out.append((name, dims, shape))
+    return out  # pytype: disable=bad-return-type
+
+  @functools.cached_property
+  def _cumulative_sizes(self) -> np.ndarray:
+    var_sizes = [math.prod(shape) for _, _, shape in self._chunk_index_shapes]
+    return np.cumsum([0] + var_sizes)
+
+  def _index_to_key(self, position: int) -> Key:
+    assert 0 <= position < self._cumulative_sizes[-1]
+    var_index = (
+        np.searchsorted(self._cumulative_sizes, position, side="right") - 1
+    )
+    offset = position - self._cumulative_sizes[var_index]
+    name, dims, shape = self._chunk_index_shapes[var_index]
+    indices = np.unravel_index(offset, shape)
+    offsets = {dim: self.offsets[dim][idx] for dim, idx in zip(dims, indices)}
+    return Key(offsets, vars=None if name is None else {name})
+
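A worked decode of a flat position, extending the `foo` example from the docstring above with a second hypothetical variable, sketched as comments:

    # With split_vars=True and two variables:
    #   ('foo', ('x', 'y'), (2, 5))  -> 10 chunks
    #   ('bar', ('x',),     (2,))    ->  2 chunks
    # _cumulative_sizes == [0, 10, 12].
    # For position == 11: np.searchsorted([0, 10, 12], 11, side='right') - 1 == 1,
    # so var_index selects 'bar'; offset == 11 - 10 == 1; and
    # np.unravel_index(1, (2,)) == (1,), i.e. the second chunk of 'bar' along 'x'.
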
+  def _get_element(self, position: int) -> tuple[Key, xarray.Dataset]:
+    return self._key_to_chunks(self._index_to_key(position))  # pytype: disable=bad-return-type
+
+  def expand(
+      self, pbegin: beam.PBegin
+  ) -> beam.PCollection[tuple[Key, xarray.Dataset]]:
+    element_count = self._task_count()
+    assert element_count > 0
+    # For simplicity, assume that all chunks are approximately the same size,
+    # even if variables are being split and some variables have different
+    # sizes. This assumption could be relaxed in the future, with an
+    # improved version of RangeSource.
+    avg_chunk_bytes = math.ceil(self._first.nbytes / element_count)
+    source = range_source.RangeSource(
+        element_count, avg_chunk_bytes, self._get_element
+    )
+    return pbegin | beam.io.Read(source)
+
+
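Since `ReadDataset` starts from `PBegin`, it would be applied directly to the pipeline once exported; a hedged usage sketch (mirroring the DatasetToChunks example above, with the same placeholder dataset):

    with beam.Pipeline() as p:
      _ = p | ReadDataset(ds, chunks={'time': 100})
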
 def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
   """Ensure that a dataset contains no chunked variables."""
   for var_name, variable in dataset.variables.items():