 import math
 import operator
 import os.path
+import re
 import tempfile
 import textwrap
 import types
@@ -364,25 +365,28 @@ def _concat_labels(label1: str | None, label2: str) -> str:
 def _whole_dataset_method(method_name: str):
   """Helper function for defining a method with a fast-path for lazy data."""
 
-  def method(self: Dataset, *args, **kwargs) -> Dataset:
+  def method(
+      self: Dataset, *args, label: str | None = None, **kwargs
+  ) -> Dataset:
+    if label is None:
+      label = _get_label(method_name)
+
     func = operator.methodcaller(method_name, *args, **kwargs)
     template = zarr.make_template(func(self.template))
     chunks = {k: v for k, v in self.chunks.items() if k in template.dims}
 
-    label = _get_label(method_name)
-
-    pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
-    if isinstance(ptransform, core.DatasetToChunks):
+    pipeline, old_ptransform = _split_lazy_pcollection(self._ptransform)
+    if isinstance(old_ptransform, core.DatasetToChunks):
       # Some transformations (e.g., indexing) can be applied much less
       # expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
       # to preserve this option for downstream transformations if possible.
-      dataset = func(ptransform.dataset)
+      dataset = func(old_ptransform.dataset)
       ptransform = core.DatasetToChunks(dataset, chunks, self.split_vars)
-      ptransform.label = _concat_labels(ptransform.label, label)
+      ptransform.label = _concat_labels(old_ptransform.label, label)
       if pipeline is not None:
         ptransform = _LazyPCollection(pipeline, ptransform)
     else:
-      ptransform = self.ptransform | label >> beam.MapTuple(
+      ptransform = old_ptransform | label >> beam.MapTuple(
          functools.partial(
              _apply_to_each_chunk, func, method_name, self.chunks, chunks
          )
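The helper above now threads `label` through both the fast path and the generic `beam.MapTuple` path because Beam requires transform labels to be unique within a pipeline, which is why a fresh name is generated via `_get_label` when none is given. A minimal standalone sketch of that Beam constraint (not part of this commit; names are illustrative):

```python
import apache_beam as beam

with beam.Pipeline() as p:
  numbers = p | 'create' >> beam.Create([1, 2, 3])
  doubled = numbers | 'scale' >> beam.Map(lambda x: 2 * x)
  # Reusing the 'scale' label for another transform on this pipeline would
  # raise an error about a duplicate transform label, so user-supplied
  # labels must be unique and defaults are auto-generated instead.
  tripled = numbers | 'scale-again' >> beam.Map(lambda x: 3 * x)
```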
@@ -558,6 +562,7 @@ def from_ptransform(
       template: xarray.Dataset,
       chunks: Mapping[str | types.EllipsisType, int],
       split_vars: bool = False,
+      label: str | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from a Beam PTransform.
 
@@ -584,10 +589,15 @@ def from_ptransform(
         except for the last chunk in each dimension, which may be smaller.
       split_vars: A boolean indicating whether the chunks in ``ptransform`` are
         split across variables, or if each chunk contains all variables.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
 
     Returns:
       An ``xarray_beam.Dataset`` instance wrapping the PTransform.
     """
+    if label is None:
+      label = _get_label('from_ptransform')
+
     if not isinstance(chunks, Mapping):
       raise TypeError(
           f'chunks must be a mapping for from_ptransform, got {chunks}'
@@ -598,8 +608,9 @@
           'chunks must be a mapping with integer values for from_ptransform,'
           f' got {chunks}'
       )
+
     chunks = normalize_chunks(chunks, template)
-    ptransform = ptransform | _get_label('validate') >> beam.MapTuple(
+    ptransform = ptransform | label >> beam.MapTuple(
         functools.partial(
             _normalize_and_validate_chunk, template, chunks, split_vars
         )
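With the new keyword, callers can name the validation stage that `from_ptransform` attaches instead of getting a generated `validate` label. A hedged usage sketch, assuming the class is exported as `xarray_beam.Dataset` and that elements are `(xarray_beam.Key, xarray.Dataset)` pairs; the label string is hypothetical:

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

source = xarray.Dataset({'foo': ('time', np.arange(8.0))})
chunked = [
    (xbeam.Key({'time': 0}), source.isel(time=slice(0, 4))),
    (xbeam.Key({'time': 4}), source.isel(time=slice(4, 8))),
]
ds = xbeam.Dataset.from_ptransform(
    beam.Create(chunked),
    template=xbeam.make_template(source),
    chunks={'time': 4},
    label='ingest-foo',  # hypothetical; any pipeline-unique string works
)
```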
@@ -615,6 +626,7 @@ def from_xarray(
       split_vars: bool = False,
       previous_chunks: Mapping[str, int] | None = None,
       pipeline: beam.Pipeline | None = None,
+      label: str | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from an xarray.Dataset.
 
@@ -628,13 +640,17 @@ def from_xarray(
         with ``normalize_chunks()``.
       pipeline: Beam pipeline to use for this dataset. If not provided, you will
         need to apply a pipeline later to compute this dataset.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
     """
+    if label is None:
+      label = _get_label('from_xarray')
     template = zarr.make_template(source)
     if previous_chunks is None:
       previous_chunks = source.sizes
     chunks = normalize_chunks(chunks, template, split_vars, previous_chunks)
     ptransform = core.DatasetToChunks(source, chunks, split_vars)
-    ptransform.label = _get_label('from_xarray')
+    ptransform.label = label
     if pipeline is not None:
       ptransform = _LazyPCollection(pipeline, ptransform)
     return cls(template, dict(chunks), split_vars, ptransform)
@@ -647,6 +663,7 @@ def from_zarr(
       chunks: UnnormalizedChunks | None = None,
       split_vars: bool = False,
       pipeline: beam.Pipeline | None = None,
+      label: str | None = None,
   ) -> Dataset:
     """Create an xarray_beam.Dataset from a Zarr store.
 
@@ -659,10 +676,14 @@ def from_zarr(
         ptransform, or all stored in the same element.
       pipeline: Beam pipeline to use for this dataset. If not provided, you will
         need to apply a pipeline later to compute this dataset.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
 
     Returns:
       New Dataset created from the Zarr store.
     """
+    if label is None:
+      label = _get_label('from_zarr')
     source, previous_chunks = zarr.open_zarr(path)
     if chunks is None:
       chunks = previous_chunks
@@ -672,7 +693,7 @@ def from_zarr(
         split_vars=split_vars,
         previous_chunks=previous_chunks,
     )
-    result.ptransform.label = _get_label('from_zarr')
+    result.ptransform.label = label
     if pipeline is not None:
       result._ptransform = _LazyPCollection(pipeline, result.ptransform)
     return result
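Explicit labels are handy when the same factory is called more than once in one pipeline. A sketch under the same `xarray_beam.Dataset` export assumption; the paths and labels are hypothetical:

```python
import xarray_beam as xbeam

# Two reads in the same pipeline need distinct stage names; explicit labels
# keep the pipeline graph readable instead of relying on generated ones.
inputs = xbeam.Dataset.from_zarr(
    'gs://example-bucket/inputs.zarr', label='read-inputs'
)
statics = xbeam.Dataset.from_zarr(
    'gs://example-bucket/statics.zarr', label='read-statics'
)
```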
@@ -737,6 +758,7 @@ def to_zarr(
       zarr_shards: UnnormalizedChunks | None = None,
       zarr_format: int | None = None,
       stage_locally: bool | None = None,
+      label: str | None = None,
   ) -> beam.PTransform | beam.PCollection:
     """Write this dataset to a Zarr file.
 
@@ -773,10 +795,15 @@ def to_zarr(
         setup on high-latency filesystems. By default, uses local staging if
         possible, which is true as long as `store` is provided as a string or
         path.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
 
     Returns:
       Beam transform that writes the dataset to a Zarr file.
     """
+    if label is None:
+      label = _get_label('to_zarr')
+
     if zarr_shards is not None:
       zarr_shards = normalize_chunks(
           zarr_shards,
@@ -825,7 +852,7 @@ def to_zarr(
     if zarr_shards is not None and zarr_format is None:
       zarr_format = 3  # required for shards
 
-    return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
+    return self.ptransform | label >> zarr.ChunksToZarr(
         path,
         self.template,
         zarr_chunks=zarr_chunks,
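End to end, the explicit labels become named stages in the pipeline graph. A round-trip sketch under the same assumptions as the earlier examples (exported `xarray_beam.Dataset`; the output path is hypothetical), where the write should execute when the pipeline context exits:

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

source = xarray.Dataset({'foo': ('time', np.arange(100.0))})

with beam.Pipeline() as p:
  ds = xbeam.Dataset.from_xarray(
      source, chunks={'time': 10}, pipeline=p, label='load-source'
  )
  ds.to_zarr('/tmp/example.zarr', label='write-output')
```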
@@ -853,6 +880,7 @@ def map_blocks(
       *,
       template: xarray.Dataset | None = None,
       chunks: Mapping[str, int] | None = None,
+      label: str | None = None,
   ) -> Dataset:
     """Map a function over the chunks of this dataset.
 
@@ -866,10 +894,15 @@ def map_blocks(
       chunks: explicit new chunk sizes created by applying ``func``. If not
         provided, an attempt will be made to infer the new chunks based on the
         existing chunks, dimension sizes and the new template.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
 
     Returns:
       New Dataset with updated chunks.
     """
+    if label is None:
+      label = _get_label('map_blocks')
+
     if template is None:
       try:
         template = func(self.template)
@@ -919,7 +952,6 @@ def map_blocks(
           f'dataset and {new_chunk_count} in the result of map_blocks'
       )
 
-    label = _get_label('map_blocks')
     func_name = getattr(func, '__name__', None)
     name = f'map-blocks-{func_name}' if func_name else 'map-blocks'
     ptransform = self.ptransform | label >> beam.MapTuple(
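A usage sketch for the new keyword (same `xarray_beam.Dataset` export assumption; variable names and the label are hypothetical). Because the function below is element-wise and shape-preserving, the template and chunks can be inferred automatically:

```python
import numpy as np
import xarray
import xarray_beam as xbeam

source = xarray.Dataset({'t2m': ('time', 273.15 + np.random.rand(40))})
ds = xbeam.Dataset.from_xarray(source, chunks={'time': 10}, label='load-t2m')

def to_celsius(chunk: xarray.Dataset) -> xarray.Dataset:
  # Element-wise and shape-preserving, so template/chunks are inferred.
  return chunk - 273.15

ds = ds.map_blocks(to_celsius, label='kelvin-to-celsius')
```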
@@ -935,6 +967,8 @@ def rechunk(
       split_vars: bool | None = None,
       min_mem: int | None = None,
       max_mem: int = 2**30,
+      *,
+      label: str | None = None,
   ) -> Dataset:
     """Rechunk this Dataset.
 
@@ -949,10 +983,15 @@ def rechunk(
         rechunking. Defaults to ``max_mem/100``.
       max_mem: optional maximum memory usage for an intermediate chunk in
         rechunking. Defaults to 1GB.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
 
     Returns:
       New Dataset with updated chunks.
     """
+    if label is None:
+      label = _get_label('rechunk')
+
     if split_vars is None:
       split_vars = self.split_vars
 
@@ -962,7 +1001,6 @@
         split_vars=split_vars,
         previous_chunks=self.chunks,
     )
-    label = _get_label('rechunk')
 
     pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
     if isinstance(ptransform, core.DatasetToChunks) and all(
@@ -995,21 +1033,23 @@
     result = rechunked if split_vars else rechunked.consolidate_variables()
     return result
 
-  def split_variables(self) -> Dataset:
+  def split_variables(self, *, label: str | None = None) -> Dataset:
     """Split variables in this Dataset into separate chunks."""
     if self.split_vars:
       return self
+    if label is None:
+      label = _get_label('split_vars')
     split_vars = True
-    label = _get_label('split_vars')
     ptransform = self.ptransform | label >> rechunk.SplitVariables()
     return type(self)(self.template, self.chunks, split_vars, ptransform)
 
-  def consolidate_variables(self) -> Dataset:
+  def consolidate_variables(self, *, label: str | None = None) -> Dataset:
     """Consolidate variables in this Dataset into a single chunk."""
     if not self.split_vars:
       return self
+    if label is None:
+      label = _get_label('consolidate_vars')
     split_vars = False
-    label = _get_label('consolidate_vars')
     ptransform = self.ptransform | label >> rechunk.ConsolidateVariables()
     return type(self)(self.template, self.chunks, split_vars, ptransform)
 
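For `rechunk` the new argument sits after a bare `*`, so it must be passed by keyword; `split_variables` and `consolidate_variables` follow the same keyword-only pattern. A sketch under the same assumptions as above (labels hypothetical; `-1` is assumed to request the full dimension, as with dask-style chunks):

```python
import numpy as np
import xarray
import xarray_beam as xbeam

source = xarray.Dataset({'foo': (('time', 'x'), np.zeros((100, 8)))})
ds = xbeam.Dataset.from_xarray(source, chunks={'time': 10}, label='load')

# label must be passed by keyword; -1 requests the full 'time' dimension.
rechunked = ds.rechunk({'time': -1, 'x': 1}, label='to-space-chunks')
split = rechunked.split_variables(label='split-by-var')
```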
@@ -1019,6 +1059,7 @@ def mean(
       *,
       skipna: bool = True,
       dtype: npt.DTypeLike | None = None,
+      label: str | None = None,
   ) -> Dataset:
     """Compute the mean of this Dataset using Beam combiners.
 
@@ -1029,6 +1070,8 @@
       dim: dimension(s) to compute the mean over.
       skipna: whether to skip missing data when computing the mean.
       dtype: the desired dtype of the resulting Dataset.
+      label: A unique name for this stage of the pipeline. Defaults to ``None``,
+        in which case a name will be generated.
 
     Returns:
       New Dataset with the mean computed.
@@ -1039,11 +1082,12 @@
       dims = [dim]
     else:
       dims = dim
+    if label is None:
+      label = _get_label(f"mean_{'_'.join(dims)}")
     template = zarr.make_template(
         self.template.mean(dim=dims, skipna=skipna, dtype=dtype)
     )
     new_chunks = {k: v for k, v in self.chunks.items() if k not in dims}
-    label = _get_label(f"mean_{'_'.join(dims)}")
     ptransform = self.ptransform | label >> combiners.MultiStageMean(
         dims=dims,
         skipna=skipna,
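The generated default encodes the reduced dimensions (e.g. `mean_time`, from the f-string above); the keyword lets callers pick something more readable. A sketch, same assumptions as the earlier examples:

```python
import numpy as np
import xarray
import xarray_beam as xbeam

source = xarray.Dataset({'foo': (('time', 'x'), np.random.rand(20, 4))})
ds = xbeam.Dataset.from_xarray(source, chunks={'time': 5}, label='load')

# Without label=..., this stage would get a generated name like 'mean_time'.
time_mean = ds.mean('time', skipna=False, label='time-mean')
```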
@@ -1056,7 +1100,9 @@
 
   _head = _whole_dataset_method('head')
 
-  def head(self, **indexers_kwargs: int) -> Dataset:
+  def head(
+      self, *, label: str | None = None, **indexers_kwargs: int
+  ) -> Dataset:
     """Return a Dataset with the first N elements of each dimension."""
     _, ptransform = _split_lazy_pcollection(self._ptransform)
     if not isinstance(ptransform, core.DatasetToChunks):
@@ -1065,11 +1111,13 @@ def head(self, **indexers_kwargs: int) -> Dataset:
           'ptransform=DatasetToChunks. This dataset has '
           f'ptransform={ptransform}'
       )
-    return self._head(**indexers_kwargs)
+    return self._head(label=label, **indexers_kwargs)
 
   _tail = _whole_dataset_method('tail')
 
-  def tail(self, **indexers_kwargs: int) -> Dataset:
+  def tail(
+      self, *, label: str | None = None, **indexers_kwargs: int
+  ) -> Dataset:
     """Return a Dataset with the last N elements of each dimension."""
     _, ptransform = _split_lazy_pcollection(self._ptransform)
     if not isinstance(ptransform, core.DatasetToChunks):
@@ -1078,7 +1126,7 @@ def tail(self, **indexers_kwargs: int) -> Dataset:
           'ptransform=DatasetToChunks. This dataset has '
           f'ptransform={ptransform}'
       )
-    return self._tail(**indexers_kwargs)
+    return self._tail(label=label, **indexers_kwargs)
 
   # thin wrappers around xarray methods
   __getitem__ = _whole_dataset_method('__getitem__')
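A final usage sketch for the updated wrappers (same export assumption as the earlier examples; the label is hypothetical). Note that `head` and `tail` only work while the dataset's ptransform is still a `DatasetToChunks`, i.e. before other Beam-level transforms have been applied:

```python
import numpy as np
import xarray
import xarray_beam as xbeam

source = xarray.Dataset({'foo': ('time', np.arange(100.0))})
ds = xbeam.Dataset.from_xarray(source, chunks={'time': 10}, label='load')

# label is keyword-only; dimension sizes go in **indexers_kwargs.
preview = ds.head(label='preview-head', time=3)
```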