|
| 1 | +# possible chart types |
| 2 | +@enum tb_chart_type tb_multiline=1 tb_margin=2 |
| 3 | + |
| 4 | +""" |
| 5 | +log_custom_scalar(logger, layout::AbstractDict; step = step(logger)) |
| 6 | +
|
| 7 | +Groups multiple scalars in the same plot to be visualized by the CUSTOM_SCALARS |
| 8 | +plugin. Note that this function sets the metadata: the actual values must be |
| 9 | +logged separately with `log_value` and referenced with the correct tag. |
| 10 | +
|
| 11 | +The `layout` argument is structured as follows: |
| 12 | +
|
| 13 | +layout = Dict(category => Dict(name => (chart_type, [tag1, tag2, ...]))) |
| 14 | +
|
| 15 | +where `category` is the main tag for the plot, `name` is the plot's name, |
| 16 | +`chart_type` is one between `tb_multiline` and `tb_margin` and the array of tags |
| 17 | +contains the actual references to the logged scalars. |
| 18 | +""" |
| 19 | +function log_custom_scalar(logger::TBLogger, layout::AbstractDict; step = nothing) |
| 20 | + summ = SummaryCollection(custom_scalar_summary(layout)) |
| 21 | + write_event(logger.file, make_event(logger, summ, step = step)) |
| 22 | +end |
| 23 | + |
| 24 | +function chart(name::String, metadata::Tuple{tb_chart_type, AbstractArray}) |
| 25 | + chart_type, tags = metadata |
| 26 | + |
| 27 | + if chart_type == tb_multiline |
| 28 | + content = MultilineChartContent(tag = tags) |
| 29 | + return Chart(title = name, multiline = content) |
| 30 | + elseif chart_type == tb_margin |
| 31 | + @assert length(tags) == 3 |
| 32 | + args = Dict(k => v for (k, v) in zip([:value, :lower, :upper], tags)) |
| 33 | + content = MarginChartContent( |
| 34 | + series = [MarginChartContent_Series(; args...)]) |
| 35 | + return Chart(title = name, margin = content) |
| 36 | + else |
| 37 | + @error "The chart type must be `tb_multiline` or `tb_margin`" |
| 38 | + end |
| 39 | +end |
| 40 | + |
| 41 | +function charts(dict::AbstractDict) |
| 42 | + [chart(name, meta) for (name, meta) in zip(keys(dict), values(dict))] |
| 43 | +end |
| 44 | + |
| 45 | +function custom_scalar_summary(layout) |
| 46 | + cat_spec = zip(keys(layout), values(layout)) |
| 47 | + categories = [Category(title = k, chart = charts(c)) for (k, c) in cat_spec] |
| 48 | + |
| 49 | + layout = Layout(category = categories) |
| 50 | + plugin_data = SummaryMetadata_PluginData(plugin_name = "custom_scalars") |
| 51 | + smd = SummaryMetadata(plugin_data = plugin_data) |
| 52 | + cs_tensor = TensorProto(dtype = _DataType.DT_STRING, |
| 53 | + string_val = [serialize_proto(layout)], |
| 54 | + tensor_shape = TensorShapeProto()) |
| 55 | + |
| 56 | + Summary_Value(tag = "custom_scalars__config__", |
| 57 | + tensor = cs_tensor, |
| 58 | + metadata = smd) |
| 59 | +end |
0 commit comments