Skip to content

Commit 7d98a7d

Browse files
shoyer and Xarray-Beam authors
authored and committed
Add label parameter to xarray_beam.Dataset methods.
This change allows users to provide custom labels for Beam stages created by `xarray_beam.Dataset` methods like `from_xarray`, `to_zarr`, `head`, `tail`, `transpose`, and others. If no label is provided, a default label is generated as before. This is useful for annotating stages in Beam pipelines, e.g., `to_zarr(..., label='my_zarr')` can be used to ensure `my_zarr` shows up in the name of the relevant Beam stage. PiperOrigin-RevId: 828180260
1 parent c1c4dce commit 7d98a7d

File tree

2 files changed

+89
-23
lines changed

2 files changed

+89
-23
lines changed

xarray_beam/_src/dataset.py

Lines changed: 71 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import math
3636
import operator
3737
import os.path
38+
import re
3839
import tempfile
3940
import textwrap
4041
import types
@@ -364,25 +365,28 @@ def _concat_labels(label1: str | None, label2: str) -> str:
364365
def _whole_dataset_method(method_name: str):
365366
"""Helper function for defining a method with a fast-path for lazy data."""
366367

367-
def method(self: Dataset, *args, **kwargs) -> Dataset:
368+
def method(
369+
self: Dataset, *args, label: str | None = None, **kwargs
370+
) -> Dataset:
371+
if label is None:
372+
label = _get_label(method_name)
373+
368374
func = operator.methodcaller(method_name, *args, **kwargs)
369375
template = zarr.make_template(func(self.template))
370376
chunks = {k: v for k, v in self.chunks.items() if k in template.dims}
371377

372-
label = _get_label(method_name)
373-
374-
pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
375-
if isinstance(ptransform, core.DatasetToChunks):
378+
pipeline, old_ptransform = _split_lazy_pcollection(self._ptransform)
379+
if isinstance(old_ptransform, core.DatasetToChunks):
376380
# Some transformations (e.g., indexing) can be applied much less
377381
# expensively to xarray.Dataset objects rather than via Xarray-Beam. Try
378382
# to preserve this option for downstream transformations if possible.
379-
dataset = func(ptransform.dataset)
383+
dataset = func(old_ptransform.dataset)
380384
ptransform = core.DatasetToChunks(dataset, chunks, self.split_vars)
381-
ptransform.label = _concat_labels(ptransform.label, label)
385+
ptransform.label = _concat_labels(old_ptransform.label, label)
382386
if pipeline is not None:
383387
ptransform = _LazyPCollection(pipeline, ptransform)
384388
else:
385-
ptransform = self.ptransform | label >> beam.MapTuple(
389+
ptransform = old_ptransform | label >> beam.MapTuple(
386390
functools.partial(
387391
_apply_to_each_chunk, func, method_name, self.chunks, chunks
388392
)
@@ -558,6 +562,7 @@ def from_ptransform(
558562
template: xarray.Dataset,
559563
chunks: Mapping[str | types.EllipsisType, int],
560564
split_vars: bool = False,
565+
label: str | None = None,
561566
) -> Dataset:
562567
"""Create an xarray_beam.Dataset from a Beam PTransform.
563568
@@ -584,10 +589,15 @@ def from_ptransform(
584589
except for the last chunk in each dimension, which may be smaller.
585590
split_vars: A boolean indicating whether the chunks in ``ptransform`` are
586591
split across variables, or if each chunk contains all variables.
592+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
593+
in which case a name will be generated.
587594
588595
Returns:
589596
An ``xarray_beam.Dataset`` instance wrapping the PTransform.
590597
"""
598+
if label is None:
599+
label = _get_label('from_ptransform')
600+
591601
if not isinstance(chunks, Mapping):
592602
raise TypeError(
593603
f'chunks must be a mapping for from_ptransform, got {chunks}'
@@ -598,8 +608,9 @@ def from_ptransform(
598608
'chunks must be a mapping with integer values for from_ptransform,'
599609
f' got {chunks}'
600610
)
611+
601612
chunks = normalize_chunks(chunks, template)
602-
ptransform = ptransform | _get_label('validate') >> beam.MapTuple(
613+
ptransform = ptransform | label >> beam.MapTuple(
603614
functools.partial(
604615
_normalize_and_validate_chunk, template, chunks, split_vars
605616
)
@@ -615,6 +626,7 @@ def from_xarray(
615626
split_vars: bool = False,
616627
previous_chunks: Mapping[str, int] | None = None,
617628
pipeline: beam.Pipeline | None = None,
629+
label: str | None = None,
618630
) -> Dataset:
619631
"""Create an xarray_beam.Dataset from an xarray.Dataset.
620632
@@ -628,13 +640,17 @@ def from_xarray(
628640
with ``normalize_chunks()``.
629641
pipeline: Beam pipeline to use for this dataset. If not provided, you will
630642
need to apply a pipeline later to compute this dataset.
643+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
644+
in which case a name will be generated.
631645
"""
646+
if label is None:
647+
label = _get_label('from_xarray')
632648
template = zarr.make_template(source)
633649
if previous_chunks is None:
634650
previous_chunks = source.sizes
635651
chunks = normalize_chunks(chunks, template, split_vars, previous_chunks)
636652
ptransform = core.DatasetToChunks(source, chunks, split_vars)
637-
ptransform.label = _get_label('from_xarray')
653+
ptransform.label = label
638654
if pipeline is not None:
639655
ptransform = _LazyPCollection(pipeline, ptransform)
640656
return cls(template, dict(chunks), split_vars, ptransform)
@@ -647,6 +663,7 @@ def from_zarr(
647663
chunks: UnnormalizedChunks | None = None,
648664
split_vars: bool = False,
649665
pipeline: beam.Pipeline | None = None,
666+
label: str | None = None,
650667
) -> Dataset:
651668
"""Create an xarray_beam.Dataset from a Zarr store.
652669
@@ -659,10 +676,14 @@ def from_zarr(
659676
ptransform, or all stored in the same element.
660677
pipeline: Beam pipeline to use for this dataset. If not provided, you will
661678
need to apply a pipeline later to compute this dataset.
679+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
680+
in which case a name will be generated.
662681
663682
Returns:
664683
New Dataset created from the Zarr store.
665684
"""
685+
if label is None:
686+
label = _get_label('from_zarr')
666687
source, previous_chunks = zarr.open_zarr(path)
667688
if chunks is None:
668689
chunks = previous_chunks
@@ -672,7 +693,7 @@ def from_zarr(
672693
split_vars=split_vars,
673694
previous_chunks=previous_chunks,
674695
)
675-
result.ptransform.label = _get_label('from_zarr')
696+
result.ptransform.label = label
676697
if pipeline is not None:
677698
result._ptransform = _LazyPCollection(pipeline, result.ptransform)
678699
return result
@@ -737,6 +758,7 @@ def to_zarr(
737758
zarr_shards: UnnormalizedChunks | None = None,
738759
zarr_format: int | None = None,
739760
stage_locally: bool | None = None,
761+
label: str | None = None,
740762
) -> beam.PTransform | beam.PCollection:
741763
"""Write this dataset to a Zarr file.
742764
@@ -773,10 +795,15 @@ def to_zarr(
773795
setup on high-latency filesystems. By default, uses local staging if
774796
possible, which is true as long as `store` is provided as a string or
775797
path.
798+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
799+
in which case a name will be generated.
776800
777801
Returns:
778802
Beam transform that writes the dataset to a Zarr file.
779803
"""
804+
if label is None:
805+
label = _get_label('to_zarr')
806+
780807
if zarr_shards is not None:
781808
zarr_shards = normalize_chunks(
782809
zarr_shards,
@@ -825,7 +852,7 @@ def to_zarr(
825852
if zarr_shards is not None and zarr_format is None:
826853
zarr_format = 3 # required for shards
827854

828-
return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
855+
return self.ptransform | label >> zarr.ChunksToZarr(
829856
path,
830857
self.template,
831858
zarr_chunks=zarr_chunks,
@@ -853,6 +880,7 @@ def map_blocks(
853880
*,
854881
template: xarray.Dataset | None = None,
855882
chunks: Mapping[str, int] | None = None,
883+
label: str | None = None,
856884
) -> Dataset:
857885
"""Map a function over the chunks of this dataset.
858886
@@ -866,10 +894,15 @@ def map_blocks(
866894
chunks: explicit new chunks sizes created by applying ``func``. If not
867895
provided, an attempt will be made to infer the new chunks based on the
868896
existing chunks, dimensions sizes and the new template.
897+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
898+
in which case a name will be generated.
869899
870900
Returns:
871901
New Dataset with updated chunks.
872902
"""
903+
if label is None:
904+
label = _get_label('map_blocks')
905+
873906
if template is None:
874907
try:
875908
template = func(self.template)
@@ -919,7 +952,6 @@ def map_blocks(
919952
f'dataset and {new_chunk_count} in the result of map_blocks'
920953
)
921954

922-
label = _get_label('map_blocks')
923955
func_name = getattr(func, '__name__', None)
924956
name = f'map-blocks-{func_name}' if func_name else 'map-blocks'
925957
ptransform = self.ptransform | label >> beam.MapTuple(
@@ -935,6 +967,8 @@ def rechunk(
935967
split_vars: bool | None = None,
936968
min_mem: int | None = None,
937969
max_mem: int = 2**30,
970+
*,
971+
label: str | None = None,
938972
) -> Dataset:
939973
"""Rechunk this Dataset.
940974
@@ -949,10 +983,15 @@ def rechunk(
949983
rechunking. Defaults to ``max_mem/100``.
950984
max_mem: optional maximum memory usage for an intermediate chunk in
951985
rechunking. Defaults to 1GB.
986+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
987+
in which case a name will be generated.
952988
953989
Returns:
954990
New Dataset with updated chunks.
955991
"""
992+
if label is None:
993+
label = _get_label('rechunk')
994+
956995
if split_vars is None:
957996
split_vars = self.split_vars
958997

@@ -962,7 +1001,6 @@ def rechunk(
9621001
split_vars=split_vars,
9631002
previous_chunks=self.chunks,
9641003
)
965-
label = _get_label('rechunk')
9661004

9671005
pipeline, ptransform = _split_lazy_pcollection(self._ptransform)
9681006
if isinstance(ptransform, core.DatasetToChunks) and all(
@@ -995,21 +1033,23 @@ def rechunk(
9951033
result = rechunked if split_vars else rechunked.consolidate_variables()
9961034
return result
9971035

998-
def split_variables(self) -> Dataset:
1036+
def split_variables(self, *, label: str | None = None) -> Dataset:
9991037
"""Split variables in this Dataset into separate chunks."""
10001038
if self.split_vars:
10011039
return self
1040+
if label is None:
1041+
label = _get_label('split_vars')
10021042
split_vars = True
1003-
label = _get_label('split_vars')
10041043
ptransform = self.ptransform | label >> rechunk.SplitVariables()
10051044
return type(self)(self.template, self.chunks, split_vars, ptransform)
10061045

1007-
def consolidate_variables(self) -> Dataset:
1046+
def consolidate_variables(self, *, label: str | None = None) -> Dataset:
10081047
"""Consolidate variables in this Dataset into a single chunk."""
10091048
if not self.split_vars:
10101049
return self
1050+
if label is None:
1051+
label = _get_label('consolidate_vars')
10111052
split_vars = False
1012-
label = _get_label('consolidate_vars')
10131053
ptransform = self.ptransform | label >> rechunk.ConsolidateVariables()
10141054
return type(self)(self.template, self.chunks, split_vars, ptransform)
10151055

@@ -1019,6 +1059,7 @@ def mean(
10191059
*,
10201060
skipna: bool = True,
10211061
dtype: npt.DTypeLike | None = None,
1062+
label: str | None = None,
10221063
) -> Dataset:
10231064
"""Compute the mean of this Dataset using Beam combiners.
10241065
@@ -1029,6 +1070,8 @@ def mean(
10291070
dim: dimension(s) to compute the mean over.
10301071
skipna: whether to skip missing data when computing the mean.
10311072
dtype: the desired dtype of the resulting Dataset.
1073+
label: A unique name for this stage of the pipeline. Defaults to ``None``,
1074+
in which case a name will be generated.
10321075
10331076
Returns:
10341077
New Dataset with the mean computed.
@@ -1039,11 +1082,12 @@ def mean(
10391082
dims = [dim]
10401083
else:
10411084
dims = dim
1085+
if label is None:
1086+
label = _get_label(f"mean_{'_'.join(dims)}")
10421087
template = zarr.make_template(
10431088
self.template.mean(dim=dims, skipna=skipna, dtype=dtype)
10441089
)
10451090
new_chunks = {k: v for k, v in self.chunks.items() if k not in dims}
1046-
label = _get_label(f"mean_{'_'.join(dims)}")
10471091
ptransform = self.ptransform | label >> combiners.MultiStageMean(
10481092
dims=dims,
10491093
skipna=skipna,
@@ -1056,7 +1100,9 @@ def mean(
10561100

10571101
_head = _whole_dataset_method('head')
10581102

1059-
def head(self, **indexers_kwargs: int) -> Dataset:
1103+
def head(
1104+
self, *, label: str | None = None, **indexers_kwargs: int
1105+
) -> Dataset:
10601106
"""Return a Dataset with the first N elements of each dimension."""
10611107
_, ptransform = _split_lazy_pcollection(self._ptransform)
10621108
if not isinstance(ptransform, core.DatasetToChunks):
@@ -1065,11 +1111,13 @@ def head(self, **indexers_kwargs: int) -> Dataset:
10651111
'ptransform=DatasetToChunks. This dataset has '
10661112
f'ptransform={ptransform}'
10671113
)
1068-
return self._head(**indexers_kwargs)
1114+
return self._head(label=label, **indexers_kwargs)
10691115

10701116
_tail = _whole_dataset_method('tail')
10711117

1072-
def tail(self, **indexers_kwargs: int) -> Dataset:
1118+
def tail(
1119+
self, *, label: str | None = None, **indexers_kwargs: int
1120+
) -> Dataset:
10731121
"""Return a Dataset with the last N elements of each dimension."""
10741122
_, ptransform = _split_lazy_pcollection(self._ptransform)
10751123
if not isinstance(ptransform, core.DatasetToChunks):
@@ -1078,7 +1126,7 @@ def tail(self, **indexers_kwargs: int) -> Dataset:
10781126
'ptransform=DatasetToChunks. This dataset has '
10791127
f'ptransform={ptransform}'
10801128
)
1081-
return self._tail(**indexers_kwargs)
1129+
return self._tail(label=label, **indexers_kwargs)
10821130

10831131
# thin wrappers around xarray methods
10841132
__getitem__ = _whole_dataset_method('__getitem__')

xarray_beam/_src/dataset_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,24 @@ def test_pipe(self):
978978
actual = mapped_ds.collect_with_direct_runner()
979979
xarray.testing.assert_identical(actual, expected)
980980

981+
def test_custom_label(self):
982+
ds = xarray.Dataset({'foo': ('x', np.arange(10))})
983+
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}, label='my_from_xarray')
984+
self.assertEqual(beam_ds.ptransform.label, 'my_from_xarray')
985+
986+
temp_dir = self.create_tempdir().full_path
987+
to_zarr = beam_ds.to_zarr(temp_dir, label='my_to_zarr')
988+
self.assertEqual(to_zarr.label, 'my_from_xarray|my_to_zarr')
989+
990+
head = beam_ds.head(x=2, label='my_head')
991+
self.assertEqual(head.ptransform.label, 'my_from_xarray|my_head')
992+
993+
tail = beam_ds.tail(x=2, label='my_tail')
994+
self.assertEqual(tail.ptransform.label, 'my_from_xarray|my_tail')
995+
996+
transpose = beam_ds.transpose(label='my_transpose')
997+
self.assertEqual(transpose.ptransform.label, 'my_from_xarray|my_transpose')
998+
981999

9821000
class MapBlocksTest(test_util.TestCase):
9831001

0 commit comments

Comments
 (0)