Skip to content

Commit

Permalink
Merge pull request #136 from darsnack/log-value
Browse files Browse the repository at this point in the history
Add ability to record and log arbitrary learner values
  • Loading branch information
darsnack authored Sep 14, 2022
2 parents f02dd9b + 7f6ceb9 commit 2bb79b3
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FluxTraining"
uuid = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
authors = ["lorenzoh <[email protected]>"]
version = "0.3.2"
version = "0.3.3"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Expand Down
3 changes: 3 additions & 0 deletions src/FluxTraining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include("./callbacks/earlystopping.jl")
include("./callbacks/custom.jl")
include("./callbacks/metrics.jl")
include("./callbacks/recorder.jl")
include("./callbacks/trace.jl")
include("./callbacks/sanitycheck.jl")

# hyperparameter scheduling
Expand Down Expand Up @@ -88,12 +89,14 @@ export AbstractCallback,
ProgressPrinter,
Metrics,
MetricsPrinter,
Traces,
TrainingPhase,
ValidationPhase,
Schedule,
Scheduler,
LogMetrics,
SmoothLoss,
LogTraces,
LogHistograms,
LogHyperParams,
LogVisualization,
Expand Down
37 changes: 37 additions & 0 deletions src/callbacks/logging/logger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,43 @@ function log_to(backends::Tuple, loggable, name, i; group = ())
end
end

"""
LogTraces(backends...) <: Callback
Callback that logs step traces to one or more [`LoggerBackend`](#)s.
See also [`LoggerBackend`](#), [`Loggables.Loggable`](#), [`log_to`](#),
[`TensorBoardBackend`](#)
Example:
```julia
logcb = LogTraces(TensorBoardBackend("tblogs"))
tracer = Traces((trace = learner -> learner.step.loss^2,), TrainingPhase)
Learner(model, lossfn; callbacks=[tracer, logcb])
```
"""
struct LogTraces <: Callback
backends::Tuple
LogTraces(backends...) = new(backends)
end

stateaccess(::LogTraces) = (cbstate = (history = Read(), tracehistory = Read()),)

function on(::StepEnd, phase, logger::LogTraces, learner)
history = learner.cbstate.history[phase]
traces = learner.cbstate.tracehistory[phase]
for trace in keys(traces)
val = last(last(traces, trace))
log_to(
logger.backends,
Loggables.Value(val),
string(trace),
history.steps,
group = ("Step", string(typeof(phase)), "Traces"))
end
end


"""
LogMetrics(backends...) <: Callback
Expand Down
53 changes: 53 additions & 0 deletions src/callbacks/trace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Traces(preprocess[, phase])
Record a trace during `phase` by apply each pre-processing function in
`preprocess` to the [`Learner`](#) to produce a trace value.
The trace is recorded at the end of each learning step.
See [`LogTraces`](#) for logging of the trace values.
```julia
cb = Traces((loss2 = learner -> learner.step.loss^2,
avg_gnorm = learner -> mean(map((_, g) -> norm(g), pairs(learner.step.grads))))
TrainingPhase)
```
"""
struct Traces{P<:Phase} <: Callback
preprocess::NamedTuple
end

function Traces(preprocess::NamedTuple, P::Type{<:Phase} = Phase)
return Traces{P}(preprocess)
end

stateaccess(::Traces) = (cbstate = (history = Read(), tracehistory = Write()),
step = Read())

function init!(traces::Traces, learner)
length(traces.preprocess) == length(unique(keys(traces.preprocess))) ||
error("Multiple traces have the same name!")
if !haskey(learner.cbstate, :tracehistory)
learner.cbstate.tracehistory = DefaultDict{Phase,MVHistory}(() -> MVHistory())
end
end

function on(::StepEnd, phase::P, traces::Traces{P}, learner) where {P<:Phase}
step = learner.cbstate.history[phase].steps
history = learner.cbstate.tracehistory[phase]
for (trace_name, f) in pairs(traces.preprocess)
val = f(learner)
push!(history, trace_name, step, val)
end
end

@testset "Traces" begin
cb = Traces((keya = learner -> sum(learner.step.ys),
keyb = learner -> sum(learner.step.ŷs)),
ValidationPhase)
learner = testlearner(Recorder(), cb)
@test_nowarn fit!(learner, 1)
@test :keya keys(learner.cbstate.tracehistory[ValidationPhase()])
@test :keyb keys(learner.cbstate.tracehistory[ValidationPhase()])
@test :keya keys(learner.cbstate.tracehistory[TrainingPhase()])
@test :keyb keys(learner.cbstate.tracehistory[TrainingPhase()])
end

0 comments on commit 2bb79b3

Please sign in to comment.