Skip to content

Commit b09a9e3

Browse files
authored
Merge pull request #66 from CarloLucibello/master
add set_step_increment!
2 parents 1c10745 + 31a600e commit b09a9e3

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ it will increment the internal counter by 1.
5252

5353
If you want to increase the counter by a different amount, or prevent it from increasing, you can log the additional message
5454
`log_step_increment=N`. The default behaviour corresponds to `N=1`. If you set `N=0` the internal counter will not be modified.
55+
The defaul behaviour for logger `lg` can be changed executing `set_step_increment!(lg, N)`.
5556

5657
See the example below:
5758
```julia

src/TBLogger.jl

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mutable struct TBLogger <: AbstractLogger
33
file::IOStream
44
all_files::Dict{String, IOStream}
55
global_step::Int
6+
step_increment::Int
67
min_level::LogLevel
78
end
89

@@ -17,21 +18,27 @@ already exists.
1718
export tb_append, tb_overwrite, tb_increment
1819

1920
"""
20-
TBLogger(logdir, [tb_increment ; time=time(), purge_step::Int, min_level=Logging.Info])
21+
TBLogger(logdir[, tb_increment];
22+
time=time(),
23+
purge_step=nothing,
24+
step_increment=1,
25+
min_level=Logging.Info)
2126
2227
Creates a TensorBoardLogger in the folder `logdir`. The second (optional)
2328
argument specifies the behaviour if the `logdir` already exhists: the default
24-
choice `tb_increment` appends an increasing number 1,2... to logdir. Other
25-
choices are `tb_overwrite`, which overwrites the previous folder and `tb_append`.
29+
choice `tb_increment` appends an increasing number 1,2... to `logdir`. Other
30+
choices are `tb_overwrite`, which overwrites the previous folder, and `tb_append`.
2631
27-
If `purge_step::Int` is passed, every step before `purge_step` will be ignored
32+
If a `purge_step::Int` is passed, every step before `purge_step` will be ignored
2833
by tensorboard (usefull in the case of restarting a crashed computation).
2934
30-
`min_level=Logging.Info` specifies the minimum level of messages logged to
31-
tensorboard
35+
`min_level` specifies the minimum level of messages logged to
36+
tensorboard.
3237
"""
3338
function TBLogger(logdir="tensorboard_logs/run", overwrite=tb_increment;
34-
time=time(), purge_step::Union{Int,Nothing}=nothing,
39+
time=time(),
40+
purge_step::Union{Int,Nothing}=nothing,
41+
step_increment = 1,
3542
min_level::LogLevel=Info)
3643

3744
logdir = init_logdir(logdir, overwrite)
@@ -40,7 +47,7 @@ function TBLogger(logdir="tensorboard_logs/run", overwrite=tb_increment;
4047
all_files = Dict(fpath => evfile)
4148
start_step = something(purge_step, 0)
4249

43-
TBLogger(logdir, evfile, all_files, start_step, min_level)
50+
TBLogger(logdir, evfile, all_files, start_step, step_increment, min_level)
4451
end
4552

4653
"""
@@ -159,6 +166,16 @@ logger when no value is passed by the user.
159166
"""
160167
set_step!(lg::TBLogger, step) = lg.global_step = step
161168

169+
"""
170+
set_step_increment!(lg, increment) -> Int
171+
172+
Sets the default increment applyed to logger `lg`'s iteration counter
173+
each time logging is performed.
174+
175+
Can be overidden by passing `log_step_increment=some_increment` when logging.
176+
"""
177+
set_step_increment!(lg::TBLogger, step) = lg.global_step = step
178+
162179
"""
163180
increment_step!(lg, Δ_Step) -> Int
164181
@@ -208,7 +225,7 @@ function CoreLogging.handle_message(lg::TBLogger, level, message, _module, group
208225
id, file, line; kwargs...)
209226
# Unpack the message
210227
summ = SummaryCollection()
211-
i_step = 1 # :log_step_increment default value
228+
i_step = lg.step_increment # :log_step_increment default value
212229

213230
if !isempty(kwargs)
214231
data = Vector{Pair{String,Any}}()

src/TensorBoardLogger.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using StatsBase #TODO: remove this. Only needed to compute histogram bins.
1010
using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,
1111
handle_message, shouldlog, min_enabled_level, catch_exceptions
1212

13-
export TBLogger, reset!, set_step!, increment_step!
13+
export TBLogger, reset!, set_step!, increment_step!, set_step_increment!
1414
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
1515
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
1616
export map_summaries

0 commit comments

Comments
 (0)