Skip to content

Commit 1c10745

Browse files
authored
Merge pull request #64 from fissoreg/master
Support for CUSTOM_SCALARS
2 parents 9fbd7e0 + 2295b51 commit 1c10745

File tree

7 files changed

+214
-2
lines changed

7 files changed

+214
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorBoardLogger"
22
uuid = "899adc3e-224a-11e9-021f-63837185c80f"
33
authors = ["Filippo Vicentini <filippovicentini@gmail.com>"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"

docs/src/explicit_interface.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,19 @@ log_audios
3838
```@docs
3939
log_embeddings
4040
```
41+
42+
## Custom Scalars plugin
43+
See [TensorBoard Custom Scalar page](https://github.com/tensorflow/tensorboard/tree/master/tensorboard/plugins/custom_scalar).
44+
45+
For example, to combine in the same plot panel the two curves logged under tags `"Curve/1"` and `"Curve/2"` you can run once the command:
46+
```julia
47+
layout = Dict("Cat" => Dict("Curve" => ("Multiline", ["Curve/1", "Curve/2"])))
48+
49+
log_custom_scalar(lg, layout)
50+
51+
```
52+
53+
See also the documentation below
54+
```@docs
55+
log_custom_scalars
56+
```

proto/layout.proto

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
syntax = "proto3";
2+
3+
package tensorboard;
4+
5+
/**
6+
* Encapsulates information on a single chart. Many charts appear in a category.
7+
*/
8+
message Chart {
9+
// The title shown atop this chart. Optional. Defaults to 'untitled'.
10+
string title = 1;
11+
12+
// The content of the chart. This depends on the type of the chart.
13+
oneof content {
14+
MultilineChartContent multiline = 2;
15+
MarginChartContent margin = 3;
16+
}
17+
}
18+
19+
/**
20+
* Encapsulates information on a single line chart. This line chart may have
21+
* lines associated with several tags.
22+
*/
23+
message MultilineChartContent {
24+
// A list of regular expressions for tags that should appear in this chart.
25+
// Tags are matched from beginning to end. Each regex captures a set of tags.
26+
repeated string tag = 1;
27+
}
28+
29+
/**
30+
* Encapsulates information on a single margin chart. A margin chart uses fill
31+
* area to visualize lower and upper bounds that surround a value.
32+
*/
33+
message MarginChartContent {
34+
/**
35+
* Encapsulates a tag of data for the chart.
36+
*/
37+
message Series {
38+
// The exact tag string associated with the scalar summaries making up the
39+
// main value between the bounds.
40+
string value = 1;
41+
42+
// The exact tag string associated with the scalar summaries making up the
43+
// lower bound.
44+
string lower = 2;
45+
46+
// The exact tag string associated with the scalar summaries making up the
47+
// upper bound.
48+
string upper = 3;
49+
}
50+
51+
// A list of data series to include within this margin chart.
52+
repeated Series series = 1;
53+
}
54+
55+
/**
56+
* A category contains a group of charts. Each category maps to a collapsible
57+
* within the dashboard.
58+
*/
59+
message Category {
60+
// This string appears atop each grouping of charts within the dashboard.
61+
string title = 1;
62+
63+
// Encapsulates data on charts to be shown in the category.
64+
repeated Chart chart = 2;
65+
66+
// Whether this category should be initially closed. False by default.
67+
bool closed = 3;
68+
}
69+
70+
/**
71+
* A layout encapsulates how charts are laid out within the custom scalars
72+
* dashboard.
73+
*/
74+
message Layout {
75+
// Version `0` is the only supported version.
76+
int32 version = 1;
77+
78+
// The categories here are rendered from top to bottom.
79+
repeated Category category = 2;
80+
}

src/Loggers/LogCustomScalar.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# possible chart types
2+
@enum tb_chart_type tb_multiline=1 tb_margin=2
3+
4+
"""
5+
log_custom_scalar(logger, layout::AbstractDict; step = step(logger))
6+
7+
Groups multiple scalars in the same plot to be visualized by the CUSTOM_SCALARS
8+
plugin. Note that this function sets the metadata: the actual values must be
9+
logged separately with `log_value` and referenced with the correct tag.
10+
11+
The `layout` argument is structured as follows:
12+
13+
layout = Dict(category => Dict(name => (chart_type, [tag1, tag2, ...])))
14+
15+
where `category` is the main tag for the plot, `name` is the plot's name,
16+
`chart_type` is one between `tb_multiline` and `tb_margin` and the array of tags
17+
contains the actual references to the logged scalars.
18+
"""
19+
function log_custom_scalar(logger::TBLogger, layout::AbstractDict; step = nothing)
20+
summ = SummaryCollection(custom_scalar_summary(layout))
21+
write_event(logger.file, make_event(logger, summ, step = step))
22+
end
23+
24+
function chart(name::String, metadata::Tuple{tb_chart_type, AbstractArray})
25+
chart_type, tags = metadata
26+
27+
if chart_type == tb_multiline
28+
content = MultilineChartContent(tag = tags)
29+
return Chart(title = name, multiline = content)
30+
elseif chart_type == tb_margin
31+
@assert length(tags) == 3
32+
args = Dict(k => v for (k, v) in zip([:value, :lower, :upper], tags))
33+
content = MarginChartContent(
34+
series = [MarginChartContent_Series(; args...)])
35+
return Chart(title = name, margin = content)
36+
else
37+
@error "The chart type must be `tb_multiline` or `tb_margin`"
38+
end
39+
end
40+
41+
function charts(dict::AbstractDict)
42+
[chart(name, meta) for (name, meta) in zip(keys(dict), values(dict))]
43+
end
44+
45+
function custom_scalar_summary(layout)
46+
cat_spec = zip(keys(layout), values(layout))
47+
categories = [Category(title = k, chart = charts(c)) for (k, c) in cat_spec]
48+
49+
layout = Layout(category = categories)
50+
plugin_data = SummaryMetadata_PluginData(plugin_name = "custom_scalars")
51+
smd = SummaryMetadata(plugin_data = plugin_data)
52+
cs_tensor = TensorProto(dtype = _DataType.DT_STRING,
53+
string_val = [serialize_proto(layout)],
54+
tensor_shape = TensorShapeProto())
55+
56+
Summary_Value(tag = "custom_scalars__config__",
57+
tensor = cs_tensor,
58+
metadata = smd)
59+
end

src/TensorBoardLogger.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,
1212

1313
export TBLogger, reset!, set_step!, increment_step!
1414
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
15-
log_audio, log_audios, log_graph, log_embeddings
15+
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
1616
export map_summaries
1717

1818
export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
1919
CHW, CWH,HWN, WHN, NHW, NWH, HWCN, WHCN, CHWN, CWHN, NHWC, NWHC, NCHW, NCWH
2020

21+
# Custom Scalar Plugin
22+
export tb_multiline, tb_margin
23+
2124
# Wrapper types
2225
export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios
2326

@@ -36,6 +39,8 @@ include("protojl/types_pb.jl")
3639
include("protojl/summary_pb.jl")
3740
include("protojl/event_pb.jl")
3841
include("protojl/plugin_text_pb.jl")
42+
include("protojl/tensorboard.jl")
43+
include("protojl/layout_pb.jl")
3944

4045
include("PNG.jl")
4146
using .PNGImage
@@ -54,6 +59,9 @@ include("Loggers/LogHistograms.jl")
5459
include("Loggers/LogAudio.jl")
5560
include("Loggers/LogEmbeddings.jl")
5661

62+
# Custom Scalar Plugin
63+
include("Loggers/LogCustomScalar.jl")
64+
5765
include("logger_dispatch.jl")
5866
include("logger_dispatch_overrides.jl")
5967

src/protojl/layout_pb.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# syntax: proto3
2+
using ProtoBuf
3+
import ProtoBuf.meta
4+
5+
mutable struct MultilineChartContent <: ProtoType
6+
tag::Base.Vector{AbstractString}
7+
MultilineChartContent(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
8+
end #mutable struct MultilineChartContent
9+
10+
mutable struct MarginChartContent_Series <: ProtoType
11+
value::AbstractString
12+
lower::AbstractString
13+
upper::AbstractString
14+
MarginChartContent_Series(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
15+
end #mutable struct MarginChartContent_Series
16+
17+
mutable struct MarginChartContent <: ProtoType
18+
series::Base.Vector{MarginChartContent_Series}
19+
MarginChartContent(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
20+
end #mutable struct MarginChartContent
21+
22+
mutable struct Chart <: ProtoType
23+
title::AbstractString
24+
multiline::MultilineChartContent
25+
margin::MarginChartContent
26+
Chart(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
27+
end #mutable struct Chart
28+
const __oneofs_Chart = Int[0,1,1]
29+
const __oneof_names_Chart = [Symbol("content")]
30+
meta(t::Type{Chart}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, __oneofs_Chart, __oneof_names_Chart, ProtoBuf.DEF_FIELD_TYPES)
31+
32+
mutable struct Category <: ProtoType
33+
title::AbstractString
34+
chart::Base.Vector{Chart}
35+
closed::Bool
36+
Category(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
37+
end #mutable struct Category
38+
39+
mutable struct Layout <: ProtoType
40+
version::Int32
41+
category::Base.Vector{Category}
42+
Layout(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o)
43+
end #mutable struct Layout
44+
45+
export Chart, MultilineChartContent, MarginChartContent_Series, MarginChartContent, Category, Layout

src/protojl/tensorboard.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
module tensorboard
2+
const _ProtoBuf_Top_ = @static isdefined(parentmodule(@__MODULE__), :_ProtoBuf_Top_) ? (parentmodule(@__MODULE__))._ProtoBuf_Top_ : parentmodule(@__MODULE__)
3+
include("layout_pb.jl")
4+
end

0 commit comments

Comments
 (0)