diff --git a/Project.toml b/Project.toml index 19811f6c8..259ba6a13 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FluxTraining" uuid = "7bf95e4d-ca32-48da-9824-f0dc5310474f" authors = ["lorenzoh "] -version = "0.3.2" +version = "0.3.3" [deps] BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl index 3224cb6b2..19221f24b 100644 --- a/src/FluxTraining.jl +++ b/src/FluxTraining.jl @@ -55,6 +55,7 @@ include("./callbacks/earlystopping.jl") include("./callbacks/custom.jl") include("./callbacks/metrics.jl") include("./callbacks/recorder.jl") +include("./callbacks/trace.jl") include("./callbacks/sanitycheck.jl") # hyperparameter scheduling @@ -88,12 +89,14 @@ export AbstractCallback, ProgressPrinter, Metrics, MetricsPrinter, + Traces, TrainingPhase, ValidationPhase, Schedule, Scheduler, LogMetrics, SmoothLoss, + LogTraces, LogHistograms, LogHyperParams, LogVisualization, diff --git a/src/callbacks/logging/logger.jl b/src/callbacks/logging/logger.jl index 3a5bab392..d5bcd2f85 100644 --- a/src/callbacks/logging/logger.jl +++ b/src/callbacks/logging/logger.jl @@ -33,6 +33,43 @@ function log_to(backends::Tuple, loggable, name, i; group = ()) end end +""" + LogTraces(backends...) <: Callback + +Callback that logs step traces to one or more [`LoggerBackend`](#)s. + +See also [`LoggerBackend`](#), [`Loggables.Loggable`](#), [`log_to`](#), +[`TensorBoardBackend`](#) + +Example: + +```julia +logcb = LogTraces(TensorBoardBackend("tblogs")) +tracer = Traces((trace = learner -> learner.step.loss^2,), TrainingPhase) +Learner(model, lossfn; callbacks=[tracer, logcb]) +``` +""" +struct LogTraces <: Callback + backends::Tuple + LogTraces(backends...) = new(backends) +end + +stateaccess(::LogTraces) = (cbstate = (history = Read(), tracehistory = Read()),) + +function on(::StepEnd, phase, logger::LogTraces, learner) + history = learner.cbstate.history[phase] + traces = learner.cbstate.tracehistory[phase] + for trace in keys(traces) + val = last(last(traces, trace)) + log_to( + logger.backends, + Loggables.Value(val), + string(trace), + history.steps, + group = ("Step", string(typeof(phase)), "Traces")) + end +end + """ LogMetrics(backends...) <: Callback diff --git a/src/callbacks/trace.jl b/src/callbacks/trace.jl new file mode 100644 index 000000000..464f70e65 --- /dev/null +++ b/src/callbacks/trace.jl @@ -0,0 +1,53 @@ +""" + Traces(preprocess[, phase]) + +Record a trace during `phase` by apply each pre-processing function in +`preprocess` to the [`Learner`](#) to produce a trace value. +The trace is recorded at the end of each learning step. + +See [`LogTraces`](#) for logging of the trace values. +```julia +cb = Traces((loss2 = learner -> learner.step.loss^2, + avg_gnorm = learner -> mean(map((_, g) -> norm(g), pairs(learner.step.grads)))) + TrainingPhase) +``` +""" +struct Traces{P<:Phase} <: Callback + preprocess::NamedTuple +end + +function Traces(preprocess::NamedTuple, P::Type{<:Phase} = Phase) + return Traces{P}(preprocess) +end + +stateaccess(::Traces) = (cbstate = (history = Read(), tracehistory = Write()), + step = Read()) + +function init!(traces::Traces, learner) + length(traces.preprocess) == length(unique(keys(traces.preprocess))) || + error("Multiple traces have the same name!") + if !haskey(learner.cbstate, :tracehistory) + learner.cbstate.tracehistory = DefaultDict{Phase,MVHistory}(() -> MVHistory()) + end +end + +function on(::StepEnd, phase::P, traces::Traces{P}, learner) where {P<:Phase} + step = learner.cbstate.history[phase].steps + history = learner.cbstate.tracehistory[phase] + for (trace_name, f) in pairs(traces.preprocess) + val = f(learner) + push!(history, trace_name, step, val) + end +end + +@testset "Traces" begin + cb = Traces((keya = learner -> sum(learner.step.ys), + keyb = learner -> sum(learner.step.ŷs)), + ValidationPhase) + learner = testlearner(Recorder(), cb) + @test_nowarn fit!(learner, 1) + @test :keya ∈ keys(learner.cbstate.tracehistory[ValidationPhase()]) + @test :keyb ∈ keys(learner.cbstate.tracehistory[ValidationPhase()]) + @test :keya ∉ keys(learner.cbstate.tracehistory[TrainingPhase()]) + @test :keyb ∉ keys(learner.cbstate.tracehistory[TrainingPhase()]) +end