Skip to content

Latest commit

 

History

History
138 lines (109 loc) · 6.57 KB

training.md

File metadata and controls

138 lines (109 loc) · 6.57 KB

Training loop

FluxTraining.jl comes with a training loop for standard supervised learning problems, but for different tasks like self-supervised learning, being able to write custom training logic is essential. The package's training loop API requires little boilerplate to convert a regular Flux.jl training loop while making it compatible with existing callbacks.

Supervised training, step-by-step

We'll explore the API step-by-step by converting a basic training loop and then discuss ways in which more complex training loops can be implemented using the same approach. The central piece of a training loop is the logic for a single training step, and in many cases, that will be all you need to implement. Below is the definition of a basic vanilla Flux.jl training step. It takes a batch of data, calculates the loss, gradients and finally updates the parameters of the model.

function step!(model, batch, params, optimizer, lossfn)
    xs, ys = batch
    grads = gradient(params) do
        ŷs = model(xs)
        loss = lossfn(ŷs, ys)
        return loss
    end
    update!(optimizer, params, grads)
end

To make a training step work with FluxTraining.jl and its callbacks, we need to

  • store data for a step so that callbacks can access it (e.g. Metrics uses ys and ŷs to evaluate metrics for each step); and
  • dispatch events so the callbacks are triggered

We first need to create a Phase and implement a method for FluxTraining.step! that dispatches on the phase type. Phases are used to define different training behaviors using the same API and to define callback functionality that is only run during certain phases. For example, Scheduler only runs during AbstractTrainingPhases but not during ValidationPhase. Let's implement such a phase and method, moving the arguments inside a Learner in the process.

struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase

function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    grads = gradient(learner.params) do
        ŷs = learner.model(xs)
        loss = learner.lossfn(ŷs, ys)
        return loss
    end
    update!(learner.optimizer, learner.params, grads)
end

Now we can already train a model using this implementation, for example using epoch!(learner, MyTrainingPhase(), dataiter). However, no callbacks would be called, since we haven't yet put in any logic that dispatches events or stores the step state. We can do both by using the helper function runstep which takes care of runnning our step logic, dispatching a StepBegin and StepEnd event before and after and handling control flow exceptions like CancelStepException. Additionally, runstep gives us a function handle which we can use to dispatch events inside the step, and state a container for storing step state. Let's use runstep and store the variables of interest inside state:

function step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs=xs, ys=ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            state.loss = learner.lossfn(state.ŷs, state.ys)
            return loss
        end
        update!(learner.optimizer, learner.params, grads)
    end
end

Now callbacks like Metrics can access variables like ys through learner.step (which is set to the last state). Finally, we can use handle to dispatch additional events:

using FluxTraining.Events: LossBegin, BackwardBegin, BackwardEnd

function step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs=xs, ys=ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            handle(LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(BackwardBegin())
            return loss
        end
        handle(BackwardEnd())
        update!(learner.optimizer, learner.params, grads)
    end
end

The result is the full implementation of FluxTraining.jl's own TrainingPhase! Now we can use epoch! to train a Learner with full support for all callbacks:

for i in 1:10
    epoch!(learner, MyTrainingPhase(), dataiter)
end

Validation

The implementation of ValidationPhase is even simpler; it runs the forward pass and stores variables so that callbacks like Metrics can access them.

struct ValidationPhase <: AbstractValidationPhase end

function step!(learner, phase::ValidationPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs=xs, ys=ys)) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end

Epoch logic

We didn't need to implement a custom epoch! method for our phase since the default is fine here: it just iterates over every batch and calls step!. In fact, let's have a look at how epoch! is implemented:

function epoch!(learner, phase::Phase, dataiter)
    runepoch(learner, phase) do handle
        for batch in dataiter
            step!(learner, phase, batch)
        end
    end
end

Here, runepoch, similarly to runstep, takes care of epoch start/stop events and control flow. If you want more control over your training loop, you can use it to write training loops that directly use step!:

phase = MyTrainingPhase()
withepoch(learner, phase) do handle
    for batch in dataiter
        step!(learner, phase, batch)
        if learner.step.loss < 0.1
            throw(CancelFittingException("Low loss reached."))
        end
    end
end

Tips

Here are some additional tips for making it easier to implement complicated training loops.

  • You can pass (named) tuples of models to the Learner constructor. For example, for generative adversarial training, you can pass in (generator = ..., critic = ...) and then refer to them inside the step! implementation, e.g. using learner.model.generator. The models' parameters will have the same structure, i.e. learner.params.generator corresponds to params(learner.model.generator).
  • You can store any data you want in state.
  • When defining a custom phase, instead of subtyping Phase you can subtype AbstractTrainingPhase or AbstractValidationPhase so that some context-specific callbacks will work out of the box with your phase type. For example, Scheduler sets hyperparameter values only during AbstractTrainingPhase.