|
| 1 | +from dataclasses import dataclass |
| 2 | +from functools import partial |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +import ipywidgets |
1 | 6 | import numpy as np
|
| 7 | +import xarray as xr |
| 8 | +from lonboard import Map |
| 9 | + |
| 10 | + |
| 11 | +def on_slider_change(change, container): |
| 12 | + owner = change["owner"] |
| 13 | + dim = owner.description |
| 14 | + |
| 15 | + indexers = { |
| 16 | + slider.description: slider.value |
| 17 | + for slider in container.dimension_sliders.children |
| 18 | + if slider.description != dim |
| 19 | + } | {dim: change["new"]} |
| 20 | + new_slice = container.obj.isel(indexers) |
| 21 | + |
| 22 | + colors = colorize(new_slice.variable, **container.colorize_kwargs) |
| 23 | + |
| 24 | + layer = container.map.layers[0] |
| 25 | + layer.get_fill_color = colors |
| 26 | + |
| 27 | + |
| 28 | +@dataclass |
| 29 | +class MapContainer: |
| 30 | + """container for the map, any control widgets and the data object""" |
| 31 | + |
| 32 | + dimension_sliders: ipywidgets.VBox |
| 33 | + map: Map |
| 34 | + obj: xr.DataArray |
| 35 | + |
| 36 | + colorize_kwargs: dict[str, Any] |
| 37 | + |
| 38 | + def render(self): |
| 39 | + # add any additional control widgets here |
| 40 | + control_box = ipywidgets.HBox([self.dimension_sliders]) |
| 41 | + |
| 42 | + return ipywidgets.VBox([self.map, control_box]) |
2 | 43 |
|
3 | 44 |
|
4 | 45 | def create_arrow_table(polygons, arr, coords=None):
|
@@ -39,35 +80,71 @@ def normalize(var, center=None):
|
39 | 80 | return normalizer(var.data)
|
40 | 81 |
|
41 | 82 |
|
| 83 | +def colorize(var, *, center, colormap, alpha): |
| 84 | + from lonboard.colormap import apply_continuous_cmap |
| 85 | + |
| 86 | + normalized_data = normalize(var, center=center) |
| 87 | + |
| 88 | + return apply_continuous_cmap(normalized_data, colormap, alpha=alpha) |
| 89 | + |
| 90 | + |
42 | 91 | def explore(
|
43 | 92 | arr,
|
44 |
| - cell_dim="cells", |
45 | 93 | cmap="viridis",
|
46 | 94 | center=None,
|
47 | 95 | alpha=None,
|
48 | 96 | coords=None,
|
49 | 97 | ):
|
50 | 98 | import lonboard
|
51 | 99 | from lonboard import SolidPolygonLayer
|
52 |
| - from lonboard.colormap import apply_continuous_cmap |
53 | 100 | from matplotlib import colormaps
|
54 | 101 |
|
55 |
| - if len(arr.dims) != 1 or cell_dim not in arr.dims: |
| 102 | + # guaranteed to be 1D |
| 103 | + cell_id_coord = arr.dggs.coord |
| 104 | + [cell_dim] = cell_id_coord.dims |
| 105 | + |
| 106 | + if cell_dim not in arr.dims: |
56 | 107 | raise ValueError(
|
57 |
| - f"exploration only works with a single dimension ('{cell_dim}')" |
| 108 | + f"exploration plotting only works with a spatial dimension ('{cell_dim}')" |
58 | 109 | )
|
59 | 110 |
|
60 |
| - cell_ids = arr.dggs.coord.data |
| 111 | + cell_ids = cell_id_coord.data |
61 | 112 | grid_info = arr.dggs.grid_info
|
62 | 113 |
|
63 | 114 | polygons = grid_info.cell_boundaries(cell_ids, backend="geoarrow")
|
64 | 115 |
|
65 |
| - normalized_data = normalize(arr.variable, center=center) |
| 116 | + initial_indexers = {dim: 0 for dim in arr.dims if dim != cell_dim} |
| 117 | + initial_arr = arr.isel(initial_indexers) |
66 | 118 |
|
67 | 119 | colormap = colormaps[cmap] if isinstance(cmap, str) else cmap
|
68 |
| - colors = apply_continuous_cmap(normalized_data, colormap, alpha=alpha) |
| 120 | + colors = colorize(initial_arr, center=center, alpha=alpha, colormap=colormap) |
69 | 121 |
|
70 |
| - table = create_arrow_table(polygons, arr, coords=coords) |
| 122 | + table = create_arrow_table(polygons, initial_arr, coords=coords) |
71 | 123 | layer = SolidPolygonLayer(table=table, filled=True, get_fill_color=colors)
|
72 | 124 |
|
73 |
| - return lonboard.Map(layer) |
| 125 | + map_ = lonboard.Map(layer) |
| 126 | + |
| 127 | + if not initial_indexers: |
| 128 | + # 1D data |
| 129 | + return map_ |
| 130 | + |
| 131 | + sliders = ipywidgets.VBox( |
| 132 | + [ |
| 133 | + ipywidgets.IntSlider(min=0, max=arr.sizes[dim] - 1, description=dim) |
| 134 | + for dim in arr.dims |
| 135 | + if dim != cell_dim |
| 136 | + ] |
| 137 | + ) |
| 138 | + |
| 139 | + container = MapContainer( |
| 140 | + sliders, |
| 141 | + map_, |
| 142 | + arr, |
| 143 | + colorize_kwargs={"alpha": alpha, "center": center, "colormap": colormap}, |
| 144 | + ) |
| 145 | + |
| 146 | + # connect slider with map |
| 147 | + for slider in sliders.children: |
| 148 | + slider.observe(partial(on_slider_change, container=container), names="value") |
| 149 | + |
| 150 | + return container.render() |
0 commit comments