Commit
docstring updates
lorenzoh committed Oct 25, 2020
1 parent 3f08a23 commit 487a1be
Showing 11 changed files with 87 additions and 104 deletions.
10 changes: 9 additions & 1 deletion docs/callbacks/reference.md
@@ -23,11 +23,19 @@
- [`Scheduler`](#)
- [`ToGPU`](#)

## Utilities

There are also some callback utilities:

- [`CustomCallback`](#)
- [`throttle`](#)

## API

The following types and functions can be used to create custom callbacks. Read the [custom callbacks guide](./custom.md) for more context; a minimal sketch follows the list below.

- [`Callback`](#)
- [`stateaccess`](#)
- [`runafter`](#)
- [`resolveconflict`](#)
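For example, a minimal custom callback built from these pieces might look like the sketch below; `PrintLoss` is illustrative, and the `metricsepoch` state it reads mirrors what `Checkpointer` uses elsewhere in this commit:

```julia
import FluxTraining: Callback, stateaccess, on, Read, EpochEnd

# A callback that prints the epoch's loss at the end of every epoch.
struct PrintLoss <: Callback end

# Declare read access to the callback state it uses.
stateaccess(::PrintLoss) = (cbstate = (metricsepoch = Read(),),)

function on(::EpochEnd, phase, cb::PrintLoss, learner)
    println(last(learner.cbstate.metricsepoch[phase], :Loss)[2])
end
```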

2 changes: 2 additions & 0 deletions docs/tutorials/hyperparameters.md
@@ -45,6 +45,8 @@ schedule = Schedule(
learner = Learner(model, data, opt, lossfn, Scheduler(LearningRate => schedule))
```

For convenience, you can also use the [`onecycle`](#) function to create the `Schedule`.
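For example (a sketch, reusing the names from the snippet above):

```julia
schedule = onecycle(10, 0.01)  # one-cycle schedule over 10 epochs, peaking at 0.01
learner = Learner(model, data, opt, lossfn, Scheduler(LearningRate => schedule))
```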

## Extending

You can create and schedule your own hyperparameters.
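A sketch of what that can look like, mirroring how `LearningRate` is defined in this commit; the `Momentum` hyperparameter here is an illustrative assumption, not part of the library:

```julia
# A custom hyperparameter, parameterized by its value type.
abstract type Momentum <: FluxTraining.HyperParameter{Float64} end

# Like `LearningRate`, scheduling it requires write access to the optimizer.
FluxTraining.stateaccess(::Type{Momentum}) = (optimizer = FluxTraining.Write(),)
```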
19 changes: 3 additions & 16 deletions src/FluxTraining.jl
@@ -1,10 +1,5 @@
module FluxTraining

#=
Refactoring:
- better serialization of Learners
=#


using LightGraphs
using BSON: @load, @save
@@ -41,6 +36,7 @@ include("./callbacks/execution.jl")
include("./callbacks/logging/Loggables.jl")
include("./callbacks/logging/logger.jl")
include("./callbacks/logging/tensorboard.jl")
include("./callbacks/logging/checkpointer.jl")


# callback implementations
@@ -49,7 +45,6 @@ include("./callbacks/callbacks.jl")
include("./callbacks/custom.jl")
include("./callbacks/metrics.jl")
include("./callbacks/recorder.jl")
include("./callbacks/checkpointer.jl")

# hyperparameter scheduling
include("./callbacks/hyperparameters.jl")
@@ -62,7 +57,6 @@ include("./learner.jl")
include("./train.jl")


# TODO: remove old exports
export AbstractCallback,
Loss,
ConditionalCallback,
@@ -87,18 +81,11 @@ export AbstractCallback,
Logger,
TensorBoardBackend,
StopOnNaNLoss,

LearningRate,

throttle,
accuracy,
fit!,
loadmodel,
onecycle,
savemodel,
saveweights,
setschedule!,
splitdataset,
starttraining

loadmodel,
savemodel
end # module
45 changes: 0 additions & 45 deletions src/callbacks/checkpointer.jl

This file was deleted.

1 change: 0 additions & 1 deletion src/callbacks/graph.jl
@@ -1,7 +1,6 @@

# TODO: check for reads on :cbstate without a previous write (missing callback)
# TODO: check for cyclical dependencies
# TODO: better error messages
"""
callbackgraph(callbacks) -> SimpleDiGraph
7 changes: 7 additions & 0 deletions src/callbacks/hyperparameters.jl
@@ -29,6 +29,13 @@ stateaccess(::Type{HyperParameter}) = ()

# Implementations

"""
abstract type LearningRate <: HyperParameter
Hyperparameter for the optimizer's learning rate.
See [`Scheduler`](#) and [hyperparameter scheduling](./docs/tutorials/hyperparameters.md).
"""
abstract type LearningRate <: HyperParameter{Float64} end

stateaccess(::Type{LearningRate}) = (optimizer = Write(),)
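The write access declared above is what lets a [`Scheduler`](#) mutate the optimizer's learning rate; a usage sketch, assuming `model`, `data`, `opt`, and `lossfn` are already defined:

```julia
# Drive the learning rate with a one-cycle schedule over 10 epochs.
scheduler = Scheduler(LearningRate => onecycle(10, 0.01))
learner = Learner(model, data, opt, lossfn, scheduler)
```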
71 changes: 36 additions & 35 deletions src/callbacks/logging/checkpointer.jl
@@ -1,44 +1,45 @@
## Removed: the old condition-based implementation.

# CheckpointCondition
abstract type CheckpointCondition end

struct CheckpointAny <: CheckpointCondition end

mutable struct CheckpointLowest <: CheckpointCondition
    lowest::Real
end
CheckpointLowest() = CheckpointLowest(Inf)

(::CheckpointAny)(loss) = true

function (checklowest::CheckpointLowest)(loss)
    cond = loss < checklowest.lowest
    if cond
        checklowest.lowest = loss
    end
    return cond
end

# Checkpointer
struct Checkpointer <: SafeCallback
    condition::CheckpointCondition
    deleteprevious::Bool
end

Checkpointer(condition = CheckpointLowest(); deleteprevious = false) = Checkpointer(
    condition, deleteprevious)

# TODO: refactor to use `Metrics`
function on(::EpochEnd, ::ValidationPhase, checkpointer::Checkpointer, learner)
    loss = epochvalue(getloss(learner.callbacks))
    if checkpointer.condition(loss)
        if checkpointer.deleteprevious
            previousfiles = glob("model-chckpnt-E*", artifactpath(learner))
            foreach(rm, previousfiles)
        end
        epochs = learner.cbstate.history.epochs
        filename = "model-chckpnt-E$epochs-L$loss.bson"
        path = joinpath(artifactpath(learner), filename)
        savemodel(learner.model, path)
    end
end

stateaccess(::Checkpointer) = (callbacks = Read(), cbstate = (history = Read()))

## Added: the simplified, documented implementation.

"""
    Checkpointer(folder)

Saves `learner.model` to `folder` after every [`AbstractTrainingPhase`](#).
Use `FluxTraining.`[`loadmodel`](#) to load a model.
"""
struct Checkpointer <: Callback
    folder
    function Checkpointer(folder)
        mkpath(folder)
        return new(folder)
    end
end

stateaccess(::Checkpointer) = (
    model = Read(),
    cbstate = (metricsepoch = Read(), history = Read())
)

function on(::EpochEnd, phase::AbstractTrainingPhase, checkpointer::Checkpointer, learner)
    loss = last(learner.cbstate.metricsepoch[phase], :Loss)[2]
    epoch = learner.cbstate.history.epochs
    filename = "checkpoint_epoch_$(lpad(string(epoch), 3, '0'))_loss_$loss.bson"
    savemodel(learner.model, joinpath(checkpointer.folder, filename))
end

## Unchanged helpers; `loadmodel` gains a docstring.

# TODO: replace with JLD2?
function savemodel(model, path)
    @save path model = cpu(model)
end

"""
    loadmodel(path)

Loads a model that was saved to `path` using `FluxTraining.`[`savemodel`](#).
"""
function loadmodel(path)
    @load path model
    return model
end
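A usage sketch of the reworked callback; the checkpoint filename below is hypothetical but follows the `checkpoint_epoch_..._loss_...` pattern from the code above:

```julia
learner = Learner(model, data, opt, lossfn, Checkpointer("checkpoints/"))
fit!(learner, 10)

# Later, restore a saved model:
model = FluxTraining.loadmodel("checkpoints/checkpoint_epoch_010_loss_0.123.bson")
```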
15 changes: 15 additions & 0 deletions src/callbacks/phases.jl
@@ -4,7 +4,22 @@
"""
module Phases

"""
abstract type Phase
Abstract supertype for all phases. See `subtypes(FluxTraining.Phase)`.
"""
abstract type Phase end

"""
abstract type AbstractTrainingPhase <: Phase
An abstract phase representing phases where parameter updates
are being made. This exists so callbacks can dispatch on it and work
with custom training phases.
The default implementation is [`TrainingPhase`](#).
"""
abstract type AbstractTrainingPhase <: Phase end

struct TrainingPhase <: AbstractTrainingPhase end
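As a sketch of why the abstract type matters: a hypothetical custom phase subtyping `AbstractTrainingPhase` automatically works with any handler written against it, such as the `Checkpointer` handler in this commit.

```julia
# Illustrative custom phase, not part of the library.
struct FinetuningPhase <: AbstractTrainingPhase end

# Any handler dispatching on the abstract type now also covers it, e.g.:
# on(::EpochEnd, ::AbstractTrainingPhase, cb::Checkpointer, learner)
```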
4 changes: 3 additions & 1 deletion src/callbacks/recorder.jl
@@ -10,7 +10,9 @@
end

"""
Recorder(epoch, step, epochstats, stepstats)
Recorder()
Maintains a [`History`](#). It's stored in `learner.cbstate.history`.
"""
struct Recorder <: Callback end
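Other callbacks and user code can then read the recorded state; a sketch, assuming a trained `learner`:

```julia
# `epochs` is the same field the Checkpointer handler above reads.
nepochs = learner.cbstate.history.epochs
```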

16 changes: 11 additions & 5 deletions src/callbacks/scheduler.jl
@@ -87,14 +87,20 @@ function on(::BatchBegin, ::AbstractTrainingPhase, scheduler::Scheduler, learner)
end


"""
onecycle(nepochs, max_val, [start_val, end_val; start_pctg])
Creates a one-cycle [`Schedule`](#) over `nepochs` epochs from `start_val`
over `max_val` to `end_val`.
"""
# Parameters renamed from the learning-rate-specific `epochs`, `max_lr`,
# `start_lr = max_lr / 10`, `end_lr = max_lr / 30` to generic names:
function onecycle(
        nepochs, max_val,
        start_val = max_val / 10,
        end_val = max_val / 30,
        start_pctg = 0.1)
    return Schedule(
        [0, nepochs * start_pctg, nepochs],
        [start_val, max_val, end_val],
        [Animations.sineio(), Animations.sineio()]
    )
end
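Under the defaults above, `onecycle(5, 0.1)` corresponds to this explicit `Schedule` (a sketch):

```julia
sched = Schedule(
    [0, 5 * 0.1, 5],                            # epochs: start, peak, end
    [0.1 / 10, 0.1, 0.1 / 30],                  # start_val, max_val, end_val
    [Animations.sineio(), Animations.sineio()]  # easing for each segment
)
```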

1 change: 1 addition & 0 deletions src/train.jl
@@ -24,6 +24,7 @@ function fit!(learner::Learner, phases::AbstractVector{<:Phase})
end

fit!(learner::Learner, phase::Phase)::Learner = fit!(learner, [phase])

"""
    fit!(learner, n)

Train `learner` for `n` epochs, where an epoch here is a training phase
followed by a validation phase.
"""
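Taken together, the `fit!` methods might be used like this (a sketch):

```julia
fit!(learner, TrainingPhase())                       # a single phase
fit!(learner, [TrainingPhase(), ValidationPhase()])  # an explicit list of phases
fit!(learner, 10)                                    # 10 epochs of training + validation
```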
